diff --git a/python/langsmith/evaluation/_arunner.py b/python/langsmith/evaluation/_arunner.py
index 7c791ac09..a2c3b2705 100644
--- a/python/langsmith/evaluation/_arunner.py
+++ b/python/langsmith/evaluation/_arunner.py
@@ -40,6 +40,7 @@
     _ExperimentManagerMixin,
     _extract_feedback_keys,
     _ForwardResults,
+    _is_langchain_runnable,
     _load_examples_map,
     _load_experiment,
     _load_tqdm,
@@ -379,8 +380,10 @@ async def _aevaluate(
     blocking: bool = True,
     experiment: Optional[Union[schemas.TracerSession, str, uuid.UUID]] = None,
 ) -> AsyncExperimentResults:
-    is_async_target = asyncio.iscoroutinefunction(target) or (
-        hasattr(target, "__aiter__") and asyncio.iscoroutine(target.__aiter__())
+    is_async_target = (
+        asyncio.iscoroutinefunction(target)
+        or (hasattr(target, "__aiter__") and asyncio.iscoroutine(target.__aiter__()))
+        or _is_langchain_runnable(target)
     )
     client = client or rt.get_cached_client()
     runs = None if is_async_target else cast(Iterable[schemas.Run], target)
@@ -940,7 +943,7 @@ def _get_run(r: run_trees.RunTree) -> None:
 def _ensure_async_traceable(
     target: ATARGET_T,
 ) -> rh.SupportsLangsmithExtra[[dict], Awaitable]:
-    if not asyncio.iscoroutinefunction(target):
+    if not asyncio.iscoroutinefunction(target) and not _is_langchain_runnable(target):
         if callable(target):
             raise ValueError(
                 "Target must be an async function. For sync functions, use evaluate."
@@ -961,7 +964,10 @@ def _ensure_async_traceable(
         )
     if rh.is_traceable_function(target):
         return target  # type: ignore
-    return rh.traceable(name="AsyncTarget")(target)
+    else:
+        if _is_langchain_runnable(target):
+            target = target.ainvoke  # type: ignore[attr-defined]
+        return rh.traceable(name="AsyncTarget")(target)
 
 
 def _aresolve_data(
diff --git a/python/langsmith/evaluation/_runner.py b/python/langsmith/evaluation/_runner.py
index 111986b76..ceb5d8561 100644
--- a/python/langsmith/evaluation/_runner.py
+++ b/python/langsmith/evaluation/_runner.py
@@ -58,6 +58,7 @@
 
 if TYPE_CHECKING:
     import pandas as pd
+    from langchain_core.runnables import Runnable
 
     DataFrame = pd.DataFrame
 else:
@@ -96,7 +97,7 @@
 
 
 def evaluate(
-    target: TARGET_T,
+    target: Union[TARGET_T, Runnable],
     /,
     data: DATA_T,
     evaluators: Optional[Sequence[EVALUATOR_T]] = None,
@@ -878,12 +879,12 @@ def _print_comparative_experiment_start(
     )
 
 
-def _is_callable(target: Union[TARGET_T, Iterable[schemas.Run]]) -> bool:
-    return callable(target) or (hasattr(target, "invoke") and callable(target.invoke))
+def _is_callable(target: Union[TARGET_T, Iterable[schemas.Run], Runnable]) -> bool:
+    return callable(target) or _is_langchain_runnable(target)
 
 
 def _evaluate(
-    target: Union[TARGET_T, Iterable[schemas.Run]],
+    target: Union[TARGET_T, Iterable[schemas.Run], Runnable],
     /,
     data: DATA_T,
     evaluators: Optional[Sequence[EVALUATOR_T]] = None,
@@ -1664,12 +1665,13 @@ def _resolve_data(
 
 
 def _ensure_traceable(
-    target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict],
+    target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict] | Runnable,
 ) -> rh.SupportsLangsmithExtra[[dict], dict]:
     """Ensure the target function is traceable."""
-    if not callable(target):
+    if not _is_callable(target):
         raise ValueError(
-            "Target must be a callable function. For example:\n\n"
+            "Target must be a callable function or a langchain/langgraph object. For "
+            "example:\n\n"
             "def predict(inputs: dict) -> dict:\n"
             "    # do work, like chain.invoke(inputs)\n"
             "    return {...}\n\n"
@@ -1679,9 +1681,11 @@ def _ensure_traceable(
             ")"
         )
     if rh.is_traceable_function(target):
-        fn = target
+        fn: rh.SupportsLangsmithExtra[[dict], dict] = target
     else:
-        fn = rh.traceable(name="Target")(target)
+        if _is_langchain_runnable(target):
+            target = target.invoke  # type: ignore[union-attr]
+        fn = rh.traceable(name="Target")(cast(Callable, target))
     return fn
 
 
@@ -1709,9 +1713,8 @@ def _resolve_experiment(
         return experiment_, runs
     # If we have runs, that means the experiment was already started.
     if runs is not None:
-        if runs is not None:
-            runs_, runs = itertools.tee(runs)
-            first_run = next(runs_)
+        runs_, runs = itertools.tee(runs)
+        first_run = next(runs_)
         experiment_ = client.read_project(project_id=first_run.session_id)
         if not experiment_.name:
             raise ValueError("Experiment name not found for provided runs.")
@@ -1923,3 +1926,17 @@ def _flatten_experiment_results(
         }
         for x in results[start:end]
     ]
+
+
+@functools.lru_cache(maxsize=1)
+def _import_langchain_runnable() -> Optional[type]:
+    try:
+        from langchain_core.runnables import Runnable
+
+        return Runnable
+    except ImportError:
+        return None
+
+
+def _is_langchain_runnable(o: Any) -> bool:
+    return bool((Runnable := _import_langchain_runnable()) and isinstance(o, Runnable))
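Note on the two evaluation modules above: a langchain Runnable can now be passed directly as the target of evaluate()/aevaluate(). _ensure_traceable substitutes target.invoke (target.ainvoke in the async runner) before applying @traceable, and _is_langchain_runnable keeps langchain_core an optional dependency: the guarded import runs at most once per process via functools.lru_cache(maxsize=1) and the check degrades to False when the package is absent. A minimal usage sketch of what this enables — the dataset name and evaluator are illustrative placeholders, not part of this diff:

    from langchain_core.runnables import RunnableLambda
    from langsmith.evaluation import evaluate

    def predict(inputs: dict) -> dict:
        # Stand-in for a real chain; evaluate() traces RunnableLambda.invoke.
        return {"output": inputs["in"] + 1}

    def exact_match(run, example) -> dict:
        # Illustrative evaluator comparing run output to the reference output.
        return {"score": run.outputs["output"] == example.outputs["expected"]}

    results = evaluate(
        RunnableLambda(predict),  # previously rejected by _ensure_traceable
        data="my-dataset",        # assumed pre-existing dataset name
        evaluators=[exact_match],
    )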
For " + "example:\n\n" "def predict(inputs: dict) -> dict:\n" " # do work, like chain.invoke(inputs)\n" " return {...}\n\n" @@ -1679,9 +1681,11 @@ def _ensure_traceable( ")" ) if rh.is_traceable_function(target): - fn = target + fn: rh.SupportsLangsmithExtra[[dict], dict] = target else: - fn = rh.traceable(name="Target")(target) + if _is_langchain_runnable(target): + target = target.invoke # type: ignore[union-attr] + fn = rh.traceable(name="Target")(cast(Callable, target)) return fn @@ -1709,9 +1713,8 @@ def _resolve_experiment( return experiment_, runs # If we have runs, that means the experiment was already started. if runs is not None: - if runs is not None: - runs_, runs = itertools.tee(runs) - first_run = next(runs_) + runs_, runs = itertools.tee(runs) + first_run = next(runs_) experiment_ = client.read_project(project_id=first_run.session_id) if not experiment_.name: raise ValueError("Experiment name not found for provided runs.") @@ -1923,3 +1926,17 @@ def _flatten_experiment_results( } for x in results[start:end] ] + + +@functools.lru_cache(maxsize=1) +def _import_langchain_runnable() -> Optional[type]: + try: + from langchain_core.runnables import Runnable + + return Runnable + except ImportError: + return None + + +def _is_langchain_runnable(o: Any) -> bool: + return bool((Runnable := _import_langchain_runnable()) and isinstance(o, Runnable)) diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index eaa838192..7510b75ee 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -151,9 +151,7 @@ def tracing_context( get_run_tree_context = get_current_run_tree -def is_traceable_function( - func: Callable[P, R], -) -> TypeGuard[SupportsLangsmithExtra[P, R]]: +def is_traceable_function(func: Any) -> TypeGuard[SupportsLangsmithExtra[P, R]]: """Check if a function is @traceable decorated.""" return ( _is_traceable_function(func) @@ -1445,7 +1443,7 @@ def _handle_container_end( LOGGER.warning(f"Unable to process trace outputs: {repr(e)}") -def _is_traceable_function(func: Callable) -> bool: +def _is_traceable_function(func: Any) -> bool: return getattr(func, "__langsmith_traceable__", False) diff --git a/python/pyproject.toml b/python/pyproject.toml index 1966dc3f3..81645c912 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langsmith" -version = "0.1.142" +version = "0.1.143" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
authors = ["LangChain "] license = "MIT" diff --git a/python/tests/unit_tests/evaluation/test_runner.py b/python/tests/unit_tests/evaluation/test_runner.py index d20960d3e..943f55b21 100644 --- a/python/tests/unit_tests/evaluation/test_runner.py +++ b/python/tests/unit_tests/evaluation/test_runner.py @@ -80,7 +80,9 @@ def request(self, verb: str, endpoint: str, *args, **kwargs): res = MagicMock() res.json.return_value = { "runs": [ - r for r in self.runs.values() if "reference_example_id" in r + r + for r in self.runs.values() + if r["trace_id"] == r["id"] and r.get("reference_example_id") ] } return res @@ -120,7 +122,8 @@ def _wait_until(condition: Callable, timeout: int = 8): @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher") @pytest.mark.parametrize("blocking", [False, True]) -def test_evaluate_results(blocking: bool) -> None: +@pytest.mark.parametrize("as_runnable", [False, True]) +def test_evaluate_results(blocking: bool, as_runnable: bool) -> None: session = mock.Mock() ds_name = "my-dataset" ds_id = "00886375-eb2a-4038-9032-efff60309896" @@ -180,6 +183,15 @@ def predict(inputs: dict) -> dict: ordering_of_stuff.append("predict") return {"output": inputs["in"] + 1} + if as_runnable: + try: + from langchain_core.runnables import RunnableLambda + except ImportError: + pytest.skip("langchain-core not installed.") + return + else: + predict = RunnableLambda(predict) + def score_value_first(run, example): ordering_of_stuff.append("evaluate") return {"score": 0.3} @@ -263,26 +275,24 @@ async def my_other_func(inputs: dict, other_val: int): with pytest.raises(ValueError, match=match): evaluate(functools.partial(my_other_func, other_val=3), data="foo") + if sys.version_info < (3, 10): + return try: from langchain_core.runnables import RunnableLambda except ImportError: pytest.skip("langchain-core not installed.") - - @RunnableLambda - def foo(inputs: dict): - return "bar" - - with pytest.raises(ValueError, match=match): - evaluate(foo.ainvoke, data="foo") - if sys.version_info < (3, 10): return with pytest.raises(ValueError, match=match): - evaluate(functools.partial(foo.ainvoke, inputs={"foo": "bar"}), data="foo") + evaluate( + functools.partial(RunnableLambda(my_func).ainvoke, inputs={"foo": "bar"}), + data="foo", + ) @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher") @pytest.mark.parametrize("blocking", [False, True]) -async def test_aevaluate_results(blocking: bool) -> None: +@pytest.mark.parametrize("as_runnable", [False, True]) +async def test_aevaluate_results(blocking: bool, as_runnable: bool) -> None: session = mock.Mock() ds_name = "my-dataset" ds_id = "00886375-eb2a-4038-9032-efff60309896" @@ -343,6 +353,15 @@ async def predict(inputs: dict) -> dict: ordering_of_stuff.append("predict") return {"output": inputs["in"] + 1} + if as_runnable: + try: + from langchain_core.runnables import RunnableLambda + except ImportError: + pytest.skip("langchain-core not installed.") + return + else: + predict = RunnableLambda(predict) + async def score_value_first(run, example): ordering_of_stuff.append("evaluate") return {"score": 0.3}