From d79fcf164da8c0c02dc535d4d48acd90bc72ebf5 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 23 Feb 2024 16:01:16 -0800 Subject: [PATCH] Wfh/strip not given (#472) --- python/Makefile | 2 +- python/langsmith/run_helpers.py | 10 +++++ python/langsmith/wrappers/_openai.py | 47 +++++++++++++++++++-- python/tests/unit_tests/test_run_helpers.py | 7 ++- 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/python/Makefile b/python/Makefile index d6eaebfd4..5e448866b 100644 --- a/python/Makefile +++ b/python/Makefile @@ -1,7 +1,7 @@ .PHONY: tests lint format tests: - poetry run pytest tests/unit_tests + poetry run pytest -n auto --durations=10 tests/unit_tests tests_watch: poetry run ptw --now . -- -vv -x tests/unit_tests diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index 35ce222c7..bdbcea871 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -129,6 +129,7 @@ class _ContainerInput(TypedDict, total=False): reduce_fn: Optional[Callable] project_name: Optional[str] run_type: ls_client.RUN_TYPE_T + process_inputs: Optional[Callable[[dict], dict]] def _container_end( @@ -207,6 +208,12 @@ def _setup_run( except TypeError as e: logger.debug(f"Failed to infer inputs for {name_}: {e}") inputs = {"args": args, "kwargs": kwargs} + process_inputs = container_input.get("process_inputs") + if process_inputs: + try: + inputs = process_inputs(inputs) + except Exception as e: + logger.error(f"Failed to filter inputs for {name_}: {e}") outer_tags = _TAGS.get() tags_ = (langsmith_extra.get("tags") or []) + (outer_tags or []) _TAGS.set(tags_) @@ -325,6 +332,7 @@ def traceable( client: Optional[ls_client.Client] = None, reduce_fn: Optional[Callable] = None, project_name: Optional[str] = None, + process_inputs: Optional[Callable[[dict], dict]] = None, ) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]: ... @@ -350,6 +358,7 @@ def traceable( called, and the run itself will be stuck in a pending state. project_name: The name of the project to log the run to. Defaults to None, which will use the default project. + process_inputs: A function to filter the inputs to the run. Defaults to None. Returns: @@ -492,6 +501,7 @@ def manual_extra_function(x): client=kwargs.pop("client", None), project_name=kwargs.pop("project_name", None), run_type=run_type, + process_inputs=kwargs.pop("process_inputs", None), ) if kwargs: warnings.warn( diff --git a/python/langsmith/wrappers/_openai.py b/python/langsmith/wrappers/_openai.py index 8a119b2a4..22b4e25c5 100644 --- a/python/langsmith/wrappers/_openai.py +++ b/python/langsmith/wrappers/_openai.py @@ -1,8 +1,20 @@ from __future__ import annotations import functools +import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + DefaultDict, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from langsmith import run_helpers @@ -16,6 +28,28 @@ from openai.types.completion import Completion C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI"]) +logger = logging.getLogger(__name__) + + +@functools.lru_cache +def _get_not_given() -> Optional[Type]: + try: + from openai._types import NotGiven + + return NotGiven + except ImportError: + return None + + +def _strip_not_given(d: dict) -> dict: + try: + not_given = _get_not_given() + if not_given is None: + return d + return {k: v for k, v in d.items() if not isinstance(v, not_given)} + except Exception as e: + logger.error(f"Error stripping NotGiven: {e}") + return d def _reduce_choices(choices: List[Choice]) -> dict: @@ -110,15 +144,22 @@ def _get_wrapper(original_create: Callable, name: str, reduce_fn: Callable) -> C @functools.wraps(original_create) def create(*args, stream: bool = False, **kwargs): decorator = run_helpers.traceable( - name=name, run_type="llm", reduce_fn=reduce_fn if stream else None + name=name, + run_type="llm", + reduce_fn=reduce_fn if stream else None, + process_inputs=_strip_not_given, ) return decorator(original_create)(*args, stream=stream, **kwargs) @functools.wraps(original_create) async def acreate(*args, stream: bool = False, **kwargs): + kwargs = _strip_not_given(kwargs) decorator = run_helpers.traceable( - name=name, run_type="llm", reduce_fn=reduce_fn if stream else None + name=name, + run_type="llm", + reduce_fn=reduce_fn if stream else None, + process_inputs=_strip_not_given, ) if stream: # TODO: This slightly alters the output to be a generator instead of the diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py index c43c60038..599d14a20 100644 --- a/python/tests/unit_tests/test_run_helpers.py +++ b/python/tests/unit_tests/test_run_helpers.py @@ -208,7 +208,10 @@ def my_iterator_fn(a, b, d): async def test_traceable_async_iterator(use_next: bool, mock_client: Client) -> None: with patch.dict(os.environ, {"LANGCHAIN_TRACING_V2": "true"}): - @traceable(client=mock_client) + def filter_inputs(kwargs: dict): + return {"a": "FOOOOOO", "b": kwargs["b"], "d": kwargs["d"]} + + @traceable(client=mock_client, process_inputs=filter_inputs) async def my_iterator_fn(a, b, d): for i in range(a + b + d): yield i @@ -234,6 +237,8 @@ async def my_iterator_fn(a, b, d): body = json.loads(call.kwargs["data"]) assert body["post"] assert body["post"][0]["outputs"]["output"] == expected + # Assert the inputs are filtered as expected + assert body["post"][0]["inputs"] == {"a": "FOOOOOO", "b": 2, "d": 3} @patch("langsmith.run_trees.Client", autospec=True)