diff --git a/python/langsmith/evaluation/_arunner.py b/python/langsmith/evaluation/_arunner.py index 729166add..6ba6095f8 100644 --- a/python/langsmith/evaluation/_arunner.py +++ b/python/langsmith/evaluation/_arunner.py @@ -5,7 +5,6 @@ import asyncio import concurrent.futures as cf import datetime -import inspect import logging import pathlib import uuid @@ -41,6 +40,7 @@ _ExperimentManagerMixin, _extract_feedback_keys, _ForwardResults, + _include_attachments, _is_langchain_runnable, _load_examples_map, _load_experiment, @@ -1058,49 +1058,6 @@ def _get_run(r: run_trees.RunTree) -> None: ) -def _include_attachments( - target: Union[ATARGET_T, Iterable[schemas.Run], AsyncIterable[dict], Runnable], -) -> bool: - """Whether the target function accepts attachments.""" - if _is_langchain_runnable(target) or not callable(target): - return False - # Check function signature - sig = inspect.signature(target) - params = list(sig.parameters.values()) - positional_params = [ - p - for p in params - if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) - and p.default is p.empty - ] - - if len(positional_params) == 0: - raise ValueError( - "Target function must accept at least one positional argument (inputs)" - ) - elif len(positional_params) > 2: - raise ValueError( - "Target function must accept at most two positional " - "arguments (inputs, attachments)" - ) - elif len(positional_params) == 2: - mismatches = [] - for i, (p, expected) in enumerate( - zip(positional_params, ("inputs", "attachments")) - ): - if p.name != expected: - mismatches.append((i, p.name)) - - if mismatches: - raise ValueError( - "When target function has two positional arguments, they must be named " - "'inputs' and 'attachments', respectively. Received: " - + ",".join(f"'{p}' at index {i}" for i, p in mismatches) - ) - - return len(positional_params) == 2 - - def _ensure_async_traceable( target: ATARGET_T, ) -> rh.SupportsLangsmithExtra[[dict], Awaitable]: diff --git a/python/langsmith/evaluation/_runner.py b/python/langsmith/evaluation/_runner.py index 199d8fa22..bc24585d0 100644 --- a/python/langsmith/evaluation/_runner.py +++ b/python/langsmith/evaluation/_runner.py @@ -1913,9 +1913,7 @@ def _ensure_traceable( return fn -def _include_attachments( - target: Union[TARGET_T, Iterable[schemas.Run], Runnable], -) -> bool: +def _include_attachments(target: Any) -> bool: """Whether the target function accepts attachments.""" if _is_langchain_runnable(target) or not callable(target): return False @@ -1923,37 +1921,39 @@ def _include_attachments( sig = inspect.signature(target) params = list(sig.parameters.values()) positional_params = [ - p - for p in params - if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) - and p.default is p.empty + p for p in params if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) ] + positional_no_default = [p for p in positional_params if p.default is p.empty] if len(positional_params) == 0: raise ValueError( - "Target function must accept at least one positional argument (inputs)" + "Target function must accept at least one positional argument (inputs)." ) - elif len(positional_params) > 2: + elif len(positional_no_default) > 2: raise ValueError( - "Target function must accept at most two positional " - "arguments (inputs, attachments)" + "Target function must accept at most two " + "arguments without default values: (inputs, attachments)." ) - elif len(positional_params) == 2: + else: mismatches = [] + num_args = 0 for i, (p, expected) in enumerate( zip(positional_params, ("inputs", "attachments")) ): if p.name != expected: mismatches.append((i, p.name)) + else: + num_args += 1 if mismatches: - raise ValueError( - "When target function has two positional arguments, they must be named " - "'inputs' and 'attachments', respectively. Received: " - + ",".join(f"'{p}' at index {i}" for i, p in mismatches) + msg = ( + "Target function is expected to have a first positional argument " + "'inputs' and optionally a second positional argument 'attachments'. " + "Received: " + ", ".join(f"'{p}' at index {i}" for i, p in mismatches) ) + raise ValueError(msg) - return len(positional_params) == 2 + return num_args == 2 def _resolve_experiment(