Skip to content

Commit

Permalink
Wfh/strip not given (#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Feb 24, 2024
1 parent 54ff955 commit d79fcf1
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.PHONY: tests lint format

tests:
poetry run pytest tests/unit_tests
poetry run pytest -n auto --durations=10 tests/unit_tests

tests_watch:
poetry run ptw --now . -- -vv -x tests/unit_tests
Expand Down
10 changes: 10 additions & 0 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class _ContainerInput(TypedDict, total=False):
reduce_fn: Optional[Callable]
project_name: Optional[str]
run_type: ls_client.RUN_TYPE_T
process_inputs: Optional[Callable[[dict], dict]]


def _container_end(
Expand Down Expand Up @@ -207,6 +208,12 @@ def _setup_run(
except TypeError as e:
logger.debug(f"Failed to infer inputs for {name_}: {e}")
inputs = {"args": args, "kwargs": kwargs}
process_inputs = container_input.get("process_inputs")
if process_inputs:
try:
inputs = process_inputs(inputs)
except Exception as e:
logger.error(f"Failed to filter inputs for {name_}: {e}")
outer_tags = _TAGS.get()
tags_ = (langsmith_extra.get("tags") or []) + (outer_tags or [])
_TAGS.set(tags_)
Expand Down Expand Up @@ -325,6 +332,7 @@ def traceable(
client: Optional[ls_client.Client] = None,
reduce_fn: Optional[Callable] = None,
project_name: Optional[str] = None,
process_inputs: Optional[Callable[[dict], dict]] = None,
) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]:
...

Expand All @@ -350,6 +358,7 @@ def traceable(
called, and the run itself will be stuck in a pending state.
project_name: The name of the project to log the run to. Defaults to None,
which will use the default project.
process_inputs: A function to filter the inputs to the run. Defaults to None.
Returns:
Expand Down Expand Up @@ -492,6 +501,7 @@ def manual_extra_function(x):
client=kwargs.pop("client", None),
project_name=kwargs.pop("project_name", None),
run_type=run_type,
process_inputs=kwargs.pop("process_inputs", None),
)
if kwargs:
warnings.warn(
Expand Down
47 changes: 44 additions & 3 deletions python/langsmith/wrappers/_openai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from __future__ import annotations

import functools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
)

from langsmith import run_helpers

Expand All @@ -16,6 +28,28 @@
from openai.types.completion import Completion

C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI"])
logger = logging.getLogger(__name__)


@functools.lru_cache
def _get_not_given() -> Optional[Type]:
try:
from openai._types import NotGiven

return NotGiven
except ImportError:
return None


def _strip_not_given(d: dict) -> dict:
    """Return ``d`` without entries whose value is openai's ``NotGiven`` sentinel.

    Stripping is strictly best-effort: if the sentinel type cannot be
    resolved (openai not installed) or anything else goes wrong, the
    mapping is returned unchanged rather than raising.
    """
    try:
        sentinel = _get_not_given()
        if sentinel is not None:
            return {key: value for key, value in d.items() if not isinstance(value, sentinel)}
        return d
    except Exception as e:
        logger.error(f"Error stripping NotGiven: {e}")
        return d


def _reduce_choices(choices: List[Choice]) -> dict:
Expand Down Expand Up @@ -110,15 +144,22 @@ def _get_wrapper(original_create: Callable, name: str, reduce_fn: Callable) -> C
@functools.wraps(original_create)
def create(*args, stream: bool = False, **kwargs):
decorator = run_helpers.traceable(
name=name, run_type="llm", reduce_fn=reduce_fn if stream else None
name=name,
run_type="llm",
reduce_fn=reduce_fn if stream else None,
process_inputs=_strip_not_given,
)

return decorator(original_create)(*args, stream=stream, **kwargs)

@functools.wraps(original_create)
async def acreate(*args, stream: bool = False, **kwargs):
kwargs = _strip_not_given(kwargs)
decorator = run_helpers.traceable(
name=name, run_type="llm", reduce_fn=reduce_fn if stream else None
name=name,
run_type="llm",
reduce_fn=reduce_fn if stream else None,
process_inputs=_strip_not_given,
)
if stream:
# TODO: This slightly alters the output to be a generator instead of the
Expand Down
7 changes: 6 additions & 1 deletion python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ def my_iterator_fn(a, b, d):
async def test_traceable_async_iterator(use_next: bool, mock_client: Client) -> None:
with patch.dict(os.environ, {"LANGCHAIN_TRACING_V2": "true"}):

@traceable(client=mock_client)
def filter_inputs(kwargs: dict):
    """Test helper: mask ``a`` with a fixed marker, pass ``b`` and ``d`` through."""
    passthrough = {key: kwargs[key] for key in ("b", "d")}
    return {"a": "FOOOOOO", **passthrough}

@traceable(client=mock_client, process_inputs=filter_inputs)
async def my_iterator_fn(a, b, d):
for i in range(a + b + d):
yield i
Expand All @@ -234,6 +237,8 @@ async def my_iterator_fn(a, b, d):
body = json.loads(call.kwargs["data"])
assert body["post"]
assert body["post"][0]["outputs"]["output"] == expected
# Assert the inputs are filtered as expected
assert body["post"][0]["inputs"] == {"a": "FOOOOOO", "b": 2, "d": 3}


@patch("langsmith.run_trees.Client", autospec=True)
Expand Down

0 comments on commit d79fcf1

Please sign in to comment.