Add input/output schemas to runnables (#11063)
This adds `input_schema` and `output_schema` properties to all
runnables: Pydantic models for the input and output types,
respectively. These are inferred from the structure of the Runnable as
much as possible; the only manual typing needed is to
- optionally add type hints to lambdas (which get translated into
input/output schemas), and
- optionally add a type hint to `RunnablePassthrough`.

These schemas can then be used to create JSON Schema descriptions of
the input and output types; see the tests, and the usage sketch below.
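
A minimal usage sketch (not part of this diff; `ChatOpenAI` and the prompt are placeholder components, assuming a configured OpenAI key):

```python
# Hedged sketch: once runnables expose pydantic input/output schemas,
# JSON Schema descriptions come straight from pydantic's .schema().
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

chain = ChatPromptTemplate.from_template("tell me about {topic}") | ChatOpenAI()

print(chain.input_schema.schema())   # JSON Schema for the chain's input
print(chain.output_schema.schema())  # JSON Schema for the chain's output
```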

- [x] Ensure no InputType or OutputType in our classes uses an abstract
base class (replace with a union of subclasses)
- [x] Implement in BaseChain and LLMChain
- [x] Implement in RunnableBranch
- [x] Implement in RunnableBinding, RunnableMap, RunnablePassthrough,
RunnableEach, RunnableRouter
- [x] Implement in LLM, Prompt, Chat Model, Output Parser, Retriever
- [x] Implement in RunnableLambda from function signature
- [x] Implement in Tool

nfcampos authored Sep 28, 2023
1 parent b05bb9e commit cfa2203
Showing 19 changed files with 2,211 additions and 86 deletions.
24 changes: 22 additions & 2 deletions libs/langchain/langchain/chains/base.py
@@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, Union

import yaml

@@ -22,7 +22,13 @@
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field, root_validator, validator
from langchain.pydantic_v1 import (
BaseModel,
Field,
create_model,
root_validator,
validator,
)
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig

@@ -56,6 +62,20 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
chains and cannot return as rich of an output as `__call__`.
"""

@property
def input_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
)

@property
def output_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
)

def invoke(
self,
input: Dict[str, Any],
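As a standalone illustration of the `create_model` pattern used in `Chain.input_schema`/`output_schema` above (the keys are hypothetical, not from this diff):

```python
from typing import Any

from pydantic import create_model  # langchain.pydantic_v1 re-exports this

# Mirrors Chain.input_schema: one optional, untyped field per input key.
ChainInput = create_model(
    "ChainInput", **{k: (Any, None) for k in ["question", "context"]}
)
print(ChainInput.schema())
# {'title': 'ChainInput', 'type': 'object',
#  'properties': {'question': {'title': 'Question'}, 'context': {'title': 'Context'}}}
```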
29 changes: 27 additions & 2 deletions libs/langchain/langchain/chains/combine_documents/base.py
@@ -1,15 +1,15 @@
"""Base interface for chains combining documents."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type

from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import Field
from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


@@ -28,6 +28,20 @@ class BaseCombineDocumentsChain(Chain, ABC):
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:

@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"CombineDocumentsInput",
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload]
)

@property
def output_schema(self) -> Type[BaseModel]:
return create_model(
"CombineDocumentsOutput",
**{self.output_key: (str, None)}, # type: ignore[call-overload]
)

@property
def input_keys(self) -> List[str]:
"""Expect input key.
@@ -153,6 +167,17 @@ def output_keys(self) -> List[str]:
"""
return self.combine_docs_chain.output_keys

@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"AnalyzeDocumentChain",
**{self.input_key: (str, None)}, # type: ignore[call-overload]
)

@property
def output_schema(self) -> Type[BaseModel]:
return self.combine_docs_chain.output_schema

def _call(
self,
inputs: Dict[str, str],
15 changes: 14 additions & 1 deletion libs/langchain/langchain/chains/combine_documents/map_reduce.py
@@ -9,7 +9,7 @@
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import Extra, root_validator
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator


class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@@ -98,6 +98,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return_intermediate_steps: bool = False
"""Return the results of the map steps in the output."""

@property
def output_schema(self) -> type[BaseModel]:
if self.return_intermediate_steps:
return create_model(
"MapReduceDocumentsOutput",
**{
self.output_key: (str, None),
"intermediate_steps": (List[str], None),
}, # type: ignore[call-overload]
)

return super().output_schema

@property
def output_keys(self) -> List[str]:
"""Expect input key.
14 changes: 13 additions & 1 deletion libs/langchain/langchain/chains/combine_documents/map_rerank.py
@@ -9,7 +9,7 @@
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.pydantic_v1 import Extra, root_validator
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator


class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@@ -77,6 +77,18 @@ class Config:
extra = Extra.forbid
arbitrary_types_allowed = True

@property
def output_schema(self) -> type[BaseModel]:
schema: Dict[str, Any] = {
self.output_key: (str, None),
}
if self.return_intermediate_steps:
schema["intermediate_steps"] = (List[str], None)
if self.metadata_keys:
schema.update({key: (Any, None) for key in self.metadata_keys})

return create_model("MapRerankOutput", **schema)

@property
def output_keys(self) -> List[str]:
"""Expect input key.
17 changes: 17 additions & 0 deletions libs/langchain/langchain/chat_models/base.py
@@ -11,6 +11,7 @@
List,
Optional,
Sequence,
Union,
cast,
)

@@ -37,9 +38,14 @@
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
@@ -107,6 +113,17 @@ class Config:

# --- Runnable methods ---

@property
def OutputType(self) -> Any:
"""Get the input type for this runnable."""
return Union[
HumanMessageChunk,
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
SystemMessageChunk,
]

def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
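A quick way to observe the effect of the `OutputType` union above (sketch only; assumes a configured `ChatOpenAI`, and that `output_schema` is derived from `OutputType` as this PR describes):

```python
from langchain.chat_models import ChatOpenAI

model = ChatOpenAI()
# Because OutputType is a union of concrete message chunk classes rather
# than the abstract BaseMessage, the generated schema enumerates each
# message type (expect an anyOf) instead of collapsing to base fields.
print(model.output_schema.schema())
```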
35 changes: 2 additions & 33 deletions libs/langchain/langchain/chat_models/litellm.py
@@ -38,6 +38,8 @@
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
@@ -53,39 +55,6 @@ class ChatLiteLLMException(Exception):
"""Error with the `LiteLLM I/O` library"""


def _truncate_at_stop_tokens(
text: str,
stop: Optional[List[str]],
) -> str:
"""Truncates text at the earliest stop token found."""
if stop is None:
return text

for stop_token in stop:
stop_token_idx = text.find(stop_token)
if stop_token_idx != -1:
text = text[:stop_token_idx]
return text


class FunctionMessage(BaseMessage):
"""Message for passing the result of executing a function back to a model."""

name: str
"""The name of the function that was executed."""

@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function"


class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""Message Chunk for passing the result of executing a function back to a model."""

pass


def _create_retry_decorator(
llm: ChatLiteLLM,
run_manager: Optional[
5 changes: 5 additions & 0 deletions libs/langchain/langchain/llms/base.py
@@ -199,6 +199,11 @@ def set_verbose(cls, verbose: Optional[bool]) -> bool:

# --- Runnable methods ---

@property
def OutputType(self) -> Type[str]:
"""Get the input type for this runnable."""
return str

def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
12 changes: 10 additions & 2 deletions libs/langchain/langchain/prompts/chat.py
@@ -28,6 +28,7 @@
)
from langchain.schema.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ChatMessage,
HumanMessage,
@@ -280,7 +281,7 @@ class ChatPromptValue(PromptValue):
A type of a prompt value that is built from messages.
"""

messages: List[BaseMessage]
messages: Sequence[BaseMessage]
"""List of messages."""

def to_string(self) -> str:
@@ -289,7 +290,14 @@ def to_string(self) -> str:

def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of messages."""
return self.messages
return list(self.messages)


class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""

messages: Sequence[AnyMessage]


class BaseChatPromptTemplate(BasePromptTemplate, ABC):
19 changes: 18 additions & 1 deletion libs/langchain/langchain/schema/language_model.py
@@ -13,8 +13,10 @@
Union,
)

from typing_extensions import TypeAlias

from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable
@@ -70,6 +72,21 @@ class BaseLanguageModel(
Each of these has an equivalent asynchronous method.
"""

@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete

# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
]

@abstractmethod
def generate_prompt(
self,
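To see why replacing the abstract `BaseMessage` with `AnyMessage` makes "a much better schema" in `InputType` above (standalone sketch, not part of the diff):

```python
from typing import List

from langchain.pydantic_v1 import BaseModel
from langchain.schema.messages import AnyMessage, BaseMessage

class Abstract(BaseModel):
    messages: List[BaseMessage]

class Concrete(BaseModel):
    messages: List[AnyMessage]

# Abstract.schema() collapses every message to BaseMessage's fields;
# Concrete.schema() enumerates HumanMessage, AIMessage, etc. via anyOf.
print(Concrete.schema())
```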