Skip to content

Commit

Permalink
Async trace context manager (#887)
Browse files Browse the repository at this point in the history
Exposed via the same `trace` CM.


Would resolve #882
  • Loading branch information
hinthornw authored Jul 20, 2024
1 parent 8773ab7 commit 0afc018
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 92 deletions.
357 changes: 266 additions & 91 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
from langsmith.env import _runtime_env

if TYPE_CHECKING:
from types import TracebackType

from langchain_core.runnables import Runnable

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -685,6 +687,270 @@ def generator_wrapper(
return decorator


class trace:
"""Manage a langsmith run in context.
This class can be used as both a synchronous and asynchronous context manager.
Parameters:
-----------
name : str
Name of the run
run_type : ls_client.RUN_TYPE_T, optional
Type of run (e.g., "chain", "llm", "tool"). Defaults to "chain".
inputs : Optional[Dict], optional
Initial input data for the run
project_name : Optional[str], optional
Associates the run with a specific project, overriding defaults
parent : Optional[Union[run_trees.RunTree, str, Mapping]], optional
Parent run, accepts RunTree, dotted order string, or tracing headers
tags : Optional[List[str]], optional
Categorization labels for the run
metadata : Optional[Mapping[str, Any]], optional
Arbitrary key-value pairs for run annotation
client : Optional[ls_client.Client], optional
LangSmith client for specifying a different tenant,
setting custom headers, or modifying API endpoint
run_id : Optional[ls_client.ID_TYPE], optional
Preset identifier for the run
reference_example_id : Optional[ls_client.ID_TYPE], optional
You typically won't set this. It associates this run with a dataset example.
This is only valid for root runs (not children) in an evaluation context.
exceptions_to_handle : Optional[Tuple[Type[BaseException], ...]], optional
Typically not set. Exception types to ignore in what is sent up to LangSmith
extra : Optional[Dict], optional
Typically not set. Use 'metadata' instead. Extra data to be sent to LangSmith.
Examples:
---------
Synchronous usage:
>>> with trace("My Operation", run_type="tool", tags=["important"]) as run:
... result = "foo" # Do some_operation()
... run.metadata["some-key"] = "some-value"
... run.end(outputs={"result": result})
Asynchronous usage:
>>> async def main():
... async with trace("Async Operation", run_type="tool", tags=["async"]) as run:
... result = "foo" # Can await some_async_operation()
... run.metadata["some-key"] = "some-value"
... # "end" just adds the outputs and sets error to None
... # The actual patching of the run happens when the context exits
... run.end(outputs={"result": result})
>>> asyncio.run(main())
Allowing pytest.skip in a test:
>>> import sys
>>> import pytest
>>> with trace("OS-Specific Test", exceptions_to_handle=(pytest.skip.Exception,)):
... if sys.platform == "win32":
... pytest.skip("Not supported on Windows")
... result = "foo" # e.g., do some unix_specific_operation()
"""

def __init__(
self,
name: str,
run_type: ls_client.RUN_TYPE_T = "chain",
*,
inputs: Optional[Dict] = None,
extra: Optional[Dict] = None,
project_name: Optional[str] = None,
parent: Optional[Union[run_trees.RunTree, str, Mapping]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Mapping[str, Any]] = None,
client: Optional[ls_client.Client] = None,
run_id: Optional[ls_client.ID_TYPE] = None,
reference_example_id: Optional[ls_client.ID_TYPE] = None,
exceptions_to_handle: Optional[Tuple[Type[BaseException], ...]] = None,
**kwargs: Any,
):
"""Initialize the trace context manager.
Warns if unsupported kwargs are passed.
"""
if kwargs:
warnings.warn(
"The `trace` context manager no longer supports the following kwargs: "
f"{sorted(kwargs.keys())}.",
DeprecationWarning,
)
self.name = name
self.run_type = run_type
self.inputs = inputs
self.extra = extra
self.project_name = project_name
self.parent = parent
# The run tree is deprecated. Keeping for backwards compat.
# Will fully merge within parent later.
self.run_tree = kwargs.get("run_tree")
self.tags = tags
self.metadata = metadata
self.client = client
self.run_id = run_id
self.reference_example_id = reference_example_id
self.exceptions_to_handle = exceptions_to_handle
self.new_run: Optional[run_trees.RunTree] = None
self.old_ctx: Optional[dict] = None

def _setup(self) -> run_trees.RunTree:
"""Set up the tracing context and create a new run.
This method initializes the tracing context, merges tags and metadata,
creates a new run (either as a child of an existing run or as a new root run),
and sets up the necessary context variables.
Returns:
run_trees.RunTree: The newly created run.
"""
self.old_ctx = get_tracing_context()
is_disabled = self.old_ctx.get("enabled", True) is False
outer_tags = _TAGS.get()
outer_metadata = _METADATA.get()
parent_run_ = _get_parent_run(
{
"parent": self.parent,
"run_tree": self.run_tree,
"client": self.client,
}
)

tags_ = sorted(set((self.tags or []) + (outer_tags or [])))
metadata = {
**(self.metadata or {}),
**(outer_metadata or {}),
"ls_method": "trace",
}

extra_outer = self.extra or {}
extra_outer["metadata"] = metadata

project_name_ = _get_project_name(self.project_name)

if parent_run_ is not None and not is_disabled:
self.new_run = parent_run_.create_child(
name=self.name,
run_id=self.run_id,
run_type=self.run_type,
extra=extra_outer,
inputs=self.inputs,
tags=tags_,
)
else:
self.new_run = run_trees.RunTree(
name=self.name,
id=ls_client._ensure_uuid(self.run_id),
reference_example_id=ls_client._ensure_uuid(
self.reference_example_id, accept_null=True
),
run_type=self.run_type,
extra=extra_outer,
project_name=project_name_ or "default",
inputs=self.inputs or {},
tags=tags_,
client=self.client, # type: ignore[arg-type]
)

if not is_disabled:
self.new_run.post()
_TAGS.set(tags_)
_METADATA.set(metadata)
_PARENT_RUN_TREE.set(self.new_run)
_PROJECT_NAME.set(project_name_)

return self.new_run

def _teardown(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Clean up the tracing context and finalize the run.
This method handles exceptions, ends the run if necessary,
patches the run if it's not disabled, and resets the tracing context.
Args:
exc_type: The type of the exception that occurred, if any.
exc_value: The exception instance that occurred, if any.
traceback: The traceback object associated with the exception, if any.
"""
if self.new_run is None:
warnings.warn("Tracing context was not set up properly.", RuntimeWarning)
return
if exc_type is not None:
if self.exceptions_to_handle and issubclass(
exc_type, self.exceptions_to_handle
):
tb = None
else:
tb = utils._format_exc()
tb = f"{exc_type.__name__}: {exc_value}\n\n{tb}"
self.new_run.end(error=tb)
if self.old_ctx is not None:
is_disabled = self.old_ctx.get("enabled", True) is False
if not is_disabled:
self.new_run.patch()

_set_tracing_context(self.old_ctx)
else:
warnings.warn("Tracing context was not set up properly.", RuntimeWarning)

def __enter__(self) -> run_trees.RunTree:
"""Enter the context manager synchronously.
Returns:
run_trees.RunTree: The newly created run.
"""
return self._setup()

def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
"""Exit the context manager synchronously.
Args:
exc_type: The type of the exception that occurred, if any.
exc_value: The exception instance that occurred, if any.
traceback: The traceback object associated with the exception, if any.
"""
self._teardown(exc_type, exc_value, traceback)

async def __aenter__(self) -> run_trees.RunTree:
"""Enter the context manager asynchronously.
Returns:
run_trees.RunTree: The newly created run.
"""
return await aitertools.aio_to_thread(self._setup)

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
"""Exit the context manager asynchronously.
Args:
exc_type: The type of the exception that occurred, if any.
exc_value: The exception instance that occurred, if any.
traceback: The traceback object associated with the exception, if any.
"""
if exc_type is not None:
await asyncio.shield(
aitertools.aio_to_thread(self._teardown, exc_type, exc_value, traceback)
)
else:
await aitertools.aio_to_thread(
self._teardown, exc_type, exc_value, traceback
)


def _get_project_name(project_name: Optional[str]) -> Optional[str]:
prt = _PARENT_RUN_TREE.get()
return (
Expand All @@ -698,97 +964,6 @@ def _get_project_name(project_name: Optional[str]) -> Optional[str]:
)


@contextlib.contextmanager
def trace(
name: str,
run_type: ls_client.RUN_TYPE_T = "chain",
*,
inputs: Optional[Dict] = None,
extra: Optional[Dict] = None,
project_name: Optional[str] = None,
parent: Optional[Union[run_trees.RunTree, str, Mapping]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Mapping[str, Any]] = None,
client: Optional[ls_client.Client] = None,
run_id: Optional[ls_client.ID_TYPE] = None,
reference_example_id: Optional[ls_client.ID_TYPE] = None,
exceptions_to_handle: Optional[Tuple[Type[BaseException], ...]] = None,
**kwargs: Any,
) -> Generator[run_trees.RunTree, None, None]:
"""Context manager for creating a run tree."""
if kwargs:
# In case someone was passing an executor before.
warnings.warn(
"The `trace` context manager no longer supports the following kwargs: "
f"{sorted(kwargs.keys())}.",
DeprecationWarning,
)
old_ctx = get_tracing_context()
is_disabled = old_ctx.get("enabled", True) is False
outer_tags = _TAGS.get()
outer_metadata = _METADATA.get()
parent_run_ = _get_parent_run(
{"parent": parent, "run_tree": kwargs.get("run_tree"), "client": client}
)

# Merge context variables
tags_ = sorted(set((tags or []) + (outer_tags or [])))
metadata = {**(metadata or {}), **(outer_metadata or {}), "ls_method": "trace"}

extra_outer = extra or {}
extra_outer["metadata"] = metadata

project_name_ = _get_project_name(project_name)
# If it's disabled, we break the tree
if parent_run_ is not None and not is_disabled:
new_run = parent_run_.create_child(
name=name,
run_id=run_id,
run_type=run_type,
extra=extra_outer,
inputs=inputs,
tags=tags_,
)
else:
new_run = run_trees.RunTree(
name=name,
id=ls_client._ensure_uuid(run_id),
reference_example_id=ls_client._ensure_uuid(
reference_example_id, accept_null=True
),
run_type=run_type,
extra=extra_outer,
project_name=project_name_, # type: ignore[arg-type]
inputs=inputs or {},
tags=tags_,
client=client, # type: ignore[arg-type]
)
if not is_disabled:
new_run.post()
_TAGS.set(tags_)
_METADATA.set(metadata)
_PARENT_RUN_TREE.set(new_run)
_PROJECT_NAME.set(project_name_)

try:
yield new_run
except (Exception, KeyboardInterrupt, BaseException) as e:
if exceptions_to_handle and isinstance(e, exceptions_to_handle):
tb = None
else:
tb = utils._format_exc()
tb = f"{e.__class__.__name__}: {e}\n\n{tb}"
new_run.end(error=tb)
if not is_disabled:
new_run.patch()
raise e
finally:
# Reset the old context
_set_tracing_context(old_ctx)
if not is_disabled:
new_run.patch()


def as_runnable(traceable_fn: Callable) -> Runnable:
"""Convert a function wrapped by the LangSmith @traceable decorator to a Runnable.
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.92"
version = "0.1.93"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <[email protected]>"]
license = "MIT"
Expand Down
Loading

0 comments on commit 0afc018

Please sign in to comment.