Skip to content

Commit

Permalink
Add Annotation.Root to make it easier to access State, Update and Nod…
Browse files Browse the repository at this point in the history
…e types
  • Loading branch information
nfcampos committed Aug 12, 2024
1 parent 802111d commit 2d4b9b8
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 86 deletions.
104 changes: 104 additions & 0 deletions langgraph/src/graph/annotation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import { RunnableLike } from "@langchain/core/runnables";
import { BaseChannel } from "../channels/base.js";
import { BinaryOperator, BinaryOperatorAggregate } from "../channels/binop.js";
import { LastValue } from "../channels/last_value.js";

export type SingleReducer<ValueType, UpdateType = ValueType> =
| {
reducer: BinaryOperator<ValueType, UpdateType>;
default?: () => ValueType;
}
| {
/**
* @deprecated Use `reducer` instead
*/
value: BinaryOperator<ValueType, UpdateType>;
default?: () => ValueType;
}
| null;

export interface StateDefinition {
[key: string]: BaseChannel | (() => BaseChannel);
}

type ExtractValueType<C> = C extends BaseChannel
? C["ValueType"]
: C extends () => BaseChannel
? ReturnType<C>["ValueType"]
: never;

type ExtractUpdateType<C> = C extends BaseChannel
? C["UpdateType"]
: C extends () => BaseChannel
? ReturnType<C>["UpdateType"]
: never;

export type StateType<SD extends StateDefinition> = {
[key in keyof SD]: ExtractValueType<SD[key]>;
};

export type UpdateType<SD extends StateDefinition> = {
[key in keyof SD]?: ExtractUpdateType<SD[key]>;
};

export type NodeType<SD extends StateDefinition> = RunnableLike<
StateType<SD>,
UpdateType<SD>
>;

export class AnnotationRoot<SD extends StateDefinition> {
lc_graph_name = "AnnotationRoot";

State: StateType<SD>;

Update: UpdateType<SD>;

Node: NodeType<SD>;

spec: SD;

constructor(s: SD) {
this.spec = s;
}
}

export function Annotation<ValueType>(): LastValue<ValueType>;

export function Annotation<ValueType, UpdateType = ValueType>(
annotation: SingleReducer<ValueType, UpdateType>
): BinaryOperatorAggregate<ValueType, UpdateType>;

export function Annotation<ValueType, UpdateType = ValueType>(
annotation?: SingleReducer<ValueType, UpdateType>
): BaseChannel<ValueType, UpdateType> {
if (annotation) {
return getChannel<ValueType, UpdateType>(annotation);
} else {
// @ts-expect-error - Annotation without reducer
return new LastValue<ValueType>();
}
}
Annotation.Root = <S extends StateDefinition>(sd: S) => new AnnotationRoot(sd);

export function getChannel<V, U = V>(
reducer: SingleReducer<V, U>
): BaseChannel<V, U> {
if (
typeof reducer === "object" &&
reducer &&
"reducer" in reducer &&
reducer.reducer
) {
return new BinaryOperatorAggregate(reducer.reducer, reducer.default);
}
if (
typeof reducer === "object" &&
reducer &&
"value" in reducer &&
reducer.value
) {
return new BinaryOperatorAggregate(reducer.value, reducer.default);
}
// @ts-expect-error - Annotation without reducer
return new LastValue<V>();
}
4 changes: 1 addition & 3 deletions langgraph/src/graph/index.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
export { Annotation, type StateType, type UpdateType } from "./annotation.js";
export { END, START, Graph } from "./graph.js";
export {
type StateGraphArgs,
StateGraph,
type CompiledStateGraph,
Annotation,
type StateType,
type UpdateType,
} from "./state.js";
export { MessageGraph, messagesStateReducer } from "./message.js";
2 changes: 1 addition & 1 deletion langgraph/src/graph/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Messages =
| BaseMessageLike;

export function messagesStateReducer(
left: Messages,
left: BaseMessage[],
right: Messages
): BaseMessage[] {
const leftArray = Array.isArray(left) ? left : [left];
Expand Down
104 changes: 23 additions & 81 deletions langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import {
RunnableLike,
} from "@langchain/core/runnables";
import { BaseChannel } from "../channels/base.js";
import { BinaryOperator, BinaryOperatorAggregate } from "../channels/binop.js";
import { END, CompiledGraph, Graph, START, Branch } from "./graph.js";
import { LastValue } from "../channels/last_value.js";
import {
ChannelWrite,
ChannelWriteEntry,
Expand All @@ -22,64 +20,17 @@ import { RunnableCallable } from "../utils.js";
import { All } from "../pregel/types.js";
import { TAG_HIDDEN } from "../constants.js";
import { InvalidUpdateError } from "../errors.js";
import {
AnnotationRoot,
getChannel,
SingleReducer,
StateDefinition,
StateType,
UpdateType,
} from "./annotation.js";

const ROOT = "__root__";

export function Annotation<ValueType>(): LastValue<ValueType>;

export function Annotation<ValueType, UpdateType = ValueType>(
annotation: SingleReducer<ValueType, UpdateType>
): BinaryOperatorAggregate<ValueType, UpdateType>;

export function Annotation<ValueType, UpdateType = ValueType>(
annotation?: SingleReducer<ValueType, UpdateType>
): BaseChannel<ValueType, UpdateType> {
if (annotation) {
return getChannel<ValueType, UpdateType>(annotation);
} else {
// @ts-expect-error - Annotation without reducer
return new LastValue<ValueType>();
}
}

interface StateDefinition {
[key: string]: BaseChannel | (() => BaseChannel);
}

type ExtractValueType<C> = C extends BaseChannel
? C["ValueType"]
: C extends () => BaseChannel
? ReturnType<C>["ValueType"]
: never;

type ExtractUpdateType<C> = C extends BaseChannel
? C["UpdateType"]
: C extends () => BaseChannel
? ReturnType<C>["UpdateType"]
: never;

export type StateType<S extends StateDefinition> = {
[key in keyof S]: ExtractValueType<S[key]>;
};

export type UpdateType<S extends StateDefinition> = {
[key in keyof S]?: ExtractUpdateType<S[key]>;
};

type SingleReducer<ValueType, UpdateType = ValueType> =
| {
reducer: BinaryOperator<ValueType, UpdateType>;
default?: () => ValueType;
}
| {
/**
* @deprecated Use `reducer` instead
*/
value: BinaryOperator<ValueType, UpdateType>;
default?: () => ValueType;
}
| null;

export type ChannelReducers<Channels extends object> = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[K in keyof Channels]: SingleReducer<Channels[K], any>;
Expand All @@ -106,13 +57,14 @@ export class StateGraph<

constructor(
fields: SD extends StateDefinition
? SD | StateGraphArgs<S>
? SD | AnnotationRoot<SD> | StateGraphArgs<S>
: StateGraphArgs<S>
) {
super();
if (isStateDefinition(fields)) {
if (isStateDefinition(fields) || isAnnotationRoot(fields)) {
const spec = isAnnotationRoot(fields) ? fields.spec : fields;
this.channels = {};
for (const [key, val] of Object.entries(fields)) {
for (const [key, val] of Object.entries(spec)) {
if (typeof val === "function") {
this.channels[key] = val();
} else {
Expand Down Expand Up @@ -261,27 +213,6 @@ function _getChannels<Channels extends Record<string, unknown> | unknown>(
return channels;
}

function getChannel<V, U = V>(reducer: SingleReducer<V, U>): BaseChannel<V, U> {
if (
typeof reducer === "object" &&
reducer &&
"reducer" in reducer &&
reducer.reducer
) {
return new BinaryOperatorAggregate(reducer.reducer, reducer.default);
}
if (
typeof reducer === "object" &&
reducer &&
"value" in reducer &&
reducer.value
) {
return new BinaryOperatorAggregate(reducer.value, reducer.default);
}
// @ts-expect-error - Annotation without reducer
return new LastValue<V>();
}

export class CompiledStateGraph<
S,
U,
Expand Down Expand Up @@ -443,3 +374,14 @@ function isStateDefinition(obj: unknown): obj is StateDefinition {
Object.values(obj).every((v) => typeof v === "function" || isBaseChannel(v))
);
}

function isAnnotationRoot<SD extends StateDefinition>(
obj: unknown | AnnotationRoot<SD>
): obj is AnnotationRoot<SD> {
return (
typeof obj === "object" &&
obj !== null &&
"lc_graph_name" in obj &&
obj.lc_graph_name === "AnnotationRoot"
);
}
3 changes: 2 additions & 1 deletion langgraph/src/tests/graph.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { describe, it, expect } from "@jest/globals";
import { Annotation, StateGraph } from "../graph/state.js";
import { StateGraph } from "../graph/state.js";
import { END, START } from "../web.js";
import { Annotation } from "../graph/annotation.js";

describe("State", () => {
it("should validate a new node key correctly ", () => {
Expand Down

0 comments on commit 2d4b9b8

Please sign in to comment.