Skip to content

Commit

Permalink
Update typehints (#448)
Browse files Browse the repository at this point in the history
So you get linter warnings if you try to do a name. It doesn't actually
alter the behavior, so if we choose to extend the number of types, it
would just be a linting issue rather than a runtime issue
  • Loading branch information
hinthornw authored Feb 15, 2024
1 parent 286a5a6 commit 01211bc
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 15 deletions.
6 changes: 5 additions & 1 deletion python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -117,6 +118,9 @@ def _is_langchain_hosted(url: str) -> bool:


ID_TYPE = Union[uuid.UUID, str]
RUN_TYPE_T = Literal[
"tool", "chain", "llm", "retriever", "embedding", "prompt", "parser"
]


def _default_retry_config() -> Retry:
Expand Down Expand Up @@ -922,7 +926,7 @@ def create_run(
self,
name: str,
inputs: Dict[str, Any],
run_type: str,
run_type: RUN_TYPE_T,
*,
project_name: Optional[str] = None,
revision_id: Optional[str] = None,
Expand Down
21 changes: 11 additions & 10 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
runtime_checkable,
)

from langsmith import client, run_trees, utils
from langsmith import client as ls_client
from langsmith import run_trees, utils

if TYPE_CHECKING:
from langchain.schema.runnable import Runnable
Expand Down Expand Up @@ -94,14 +95,14 @@ def _get_inputs(
class LangSmithExtra(TypedDict, total=False):
"""Any additional info to be injected into the run dynamically."""

reference_example_id: Optional[client.ID_TYPE]
reference_example_id: Optional[ls_client.ID_TYPE]
run_extra: Optional[Dict]
run_tree: Optional[run_trees.RunTree]
project_name: Optional[str]
metadata: Optional[Dict[str, Any]]
tags: Optional[List[str]]
run_id: Optional[client.ID_TYPE]
client: Optional[client.Client]
run_id: Optional[ls_client.ID_TYPE]
client: Optional[ls_client.Client]


class _TraceableContainer(TypedDict, total=False):
Expand Down Expand Up @@ -141,13 +142,13 @@ def _collect_extra(extra_outer: dict, langsmith_extra: LangSmithExtra) -> dict:

def _setup_run(
func: Callable,
run_type: str,
run_type: ls_client.RUN_TYPE_T,
extra_outer: dict,
langsmith_extra: Optional[LangSmithExtra] = None,
name: Optional[str] = None,
metadata: Optional[Mapping[str, Any]] = None,
tags: Optional[List[str]] = None,
client: Optional[client.Client] = None,
client: Optional[ls_client.Client] = None,
args: Any = None,
kwargs: Any = None,
) -> _TraceableContainer:
Expand Down Expand Up @@ -272,12 +273,12 @@ def traceable(

@overload
def traceable(
run_type: str = "chain",
run_type: ls_client.RUN_TYPE_T = "chain",
*,
name: Optional[str] = None,
metadata: Optional[Mapping[str, Any]] = None,
tags: Optional[List[str]] = None,
client: Optional[client.Client] = None,
client: Optional[ls_client.Client] = None,
extra: Optional[Dict] = None,
reduce_fn: Optional[Callable] = None,
) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]:
Expand Down Expand Up @@ -305,7 +306,7 @@ def traceable(
called, and the run itself will be stuck in a pending state.
"""
run_type = (
run_type: ls_client.RUN_TYPE_T = (
args[0]
if args and isinstance(args[0], str)
else (kwargs.get("run_type") or "chain")
Expand Down Expand Up @@ -582,7 +583,7 @@ def generator_wrapper(
@contextlib.contextmanager
def trace(
name: str,
run_type: str,
run_type: ls_client.RUN_TYPE_T = "chain",
*,
inputs: Optional[Dict] = None,
extra: Optional[Dict] = None,
Expand Down
4 changes: 2 additions & 2 deletions python/langsmith/run_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pydantic import Field, root_validator, validator

from langsmith import utils
from langsmith.client import ID_TYPE, Client
from langsmith.client import ID_TYPE, RUN_TYPE_T, Client
from langsmith.schemas import RunBase

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -112,7 +112,7 @@ def end(
def create_child(
self,
name: str,
run_type: str,
run_type: RUN_TYPE_T = "chain",
*,
run_id: Optional[ID_TYPE] = None,
serialized: Optional[Dict] = None,
Expand Down
4 changes: 2 additions & 2 deletions python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def test_traceable_warning() -> None:
with warnings.catch_warnings(record=True) as warning_records:
warnings.simplefilter("always")

@traceable(run_type="invalid_run_type")
@traceable(run_type="invalid_run_type") # type: ignore
def my_function() -> None:
pass

Expand All @@ -373,7 +373,7 @@ def test_traceable_wrong_run_type_pos_arg() -> None:
with warnings.catch_warnings(record=True) as warning_records:
warnings.simplefilter("always")

@traceable("my_run_type")
@traceable("my_run_type") # type: ignore
def my_function() -> None:
pass

Expand Down

0 comments on commit 01211bc

Please sign in to comment.