Skip to content

Commit

Permalink
Add support for test caching (#589)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Apr 10, 2024
1 parent e154364 commit de1435c
Show file tree
Hide file tree
Showing 16 changed files with 5,837 additions and 428 deletions.
11 changes: 5 additions & 6 deletions .github/ISSUE_TEMPLATE/bug-report.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve LangSmith. To report a _security_ issue, please instead use the security option below.
description: "Submit a bug report to help us improve LangSmith. To report a _security_ issue, please instead use the security option below."
labels: ["01 Bug Report"]
body:
- type: markdown
Expand All @@ -15,15 +15,15 @@ body:
label: Tracing Method
description: "Select whether you are tracing using LangChain or some other method:"
options:
- label: "With LangChain"
- label: "SDK/Client"
- label: "REST API"
- label: "With LangChain"
- label: "Other"

- type: checkboxes
id: runtime-language
attributes:
label: Runtime Language
label: Language
description: ""
options:
- label: "Python"
Expand All @@ -33,11 +33,10 @@ body:
- type: checkboxes
id: platform-environment
attributes:
label: LangSmith Platform Environment
label: Host
description: "Indicate whether you are connected to the hosted LangSmith platform or running locally."
options:
- label: "Hosted (https://api.smith.langchain.com)"
- label: "Local (http://localhost:1984)"
- label: "Self-hosted"
- label: "Other"

Expand All @@ -48,7 +47,7 @@ body:
description: Please share any other system info with us. You can view this by running `langsmith env` in your terminal.
placeholder: LangSmith SDK version, client runtime information,
validations:
required: false
required: true

- type: textarea
id: reproduction
Expand Down
3 changes: 2 additions & 1 deletion .github/actions/python-integration-tests/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ runs:
- name: Install dependencies
run: |
poetry install --with dev
poetry run pip install -U langchain langchain_anthropic tiktoken rapidfuzz
poetry run pip install -U langchain langchain_anthropic tiktoken rapidfuzz vcrpy
shell: bash
working-directory: python

Expand All @@ -52,6 +52,7 @@ runs:
LANGCHAIN_API_KEY: ${{ inputs.langchain-api-key }}
OPENAI_API_KEY: ${{ inputs.openai-api-key }}
ANTHROPIC_API_KEY: ${{ inputs.anthropic-api-key }}
LANGCHAIN_TEST_CACHE: "tests/cassettes"
run: make doctest
shell: bash
working-directory: python
114 changes: 88 additions & 26 deletions python/langsmith/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import threading
import uuid
import warnings
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload

from typing_extensions import TypedDict
Expand Down Expand Up @@ -69,6 +70,15 @@ def unit(*args: Any, **kwargs: Any) -> Callable:
Returns:
Callable: The decorated test function.
Environment:
- LANGCHAIN_TEST_CACHE: If set, API calls will be cached to disk to
save time and costs during testing. Recommended to commit the
cache files to your repository for faster CI/CD runs.
Requires the 'langsmith[vcr]' package to be installed.
- LANGCHAIN_TEST_TRACKING: Set this variable to 'false' to skip LangSmith
test tracking. A warning is emitted and the decorated test functions
then run as ordinary tests, without being recorded as LangSmith results.
Example:
For basic usage, simply decorate a test function with `@unit`:
Expand All @@ -81,6 +91,30 @@ def unit(*args: Any, **kwargs: Any) -> Callable:
or `wrap_*` functions) will be traced within the test case for
improved visibility and debugging.
>>> from langsmith import traceable
>>> @traceable
... def generate_numbers():
... return 3, 4
>>> @unit
... def test_nested():
... # Traced code will be included in the test case
... a, b = generate_numbers()
... assert a + b == 7
LLM calls are expensive! Cache requests by setting
`LANGCHAIN_TEST_CACHE=path/to/cache`. Check in these files to speed up
CI/CD pipelines, so your results only change when your prompt or requested
model changes.
Note that this will require that you install langsmith with the `vcr` extra:
`pip install -U "langsmith[vcr]"`
Caching is faster if you install libyaml. See
https://vcrpy.readthedocs.io/en/latest/installation.html#speed for more details.
>>> os.environ["LANGCHAIN_TEST_CACHE"] = "tests/cassettes"
>>> import openai
>>> from langsmith.wrappers import wrap_openai
>>> @unit
Expand Down Expand Up @@ -145,6 +179,7 @@ def unit(*args: Any, **kwargs: Any) -> Callable:
To run these tests, use the pytest CLI or call the test functions directly.
>>> test_addition()
>>> test_nested()
>>> test_with_fixture("Some input")
>>> test_with_expected_output("Some input", "Some")
>>> test_multiplication()
Expand All @@ -156,11 +191,21 @@ def unit(*args: Any, **kwargs: Any) -> Callable:
output_keys=kwargs.pop("output_keys", None),
client=kwargs.pop("client", None),
test_suite_name=kwargs.pop("test_suite_name", None),
cache=ls_utils.get_cache_dir(kwargs.pop("cache", None)),
)
if kwargs:
warnings.warn(f"Unexpected keyword arguments: {kwargs.keys()}")
disable_tracking = os.environ.get("LANGCHAIN_TEST_TRACKING") == "false"
if disable_tracking:
warnings.warn(
"LANGCHAIN_TEST_TRACKING is set to 'false'."
" Skipping LangSmith test tracking."
)

if args and callable(args[0]):
func = args[0]
if disable_tracking:
return func

@functools.wraps(func)
def wrapper(*test_args, **test_kwargs):
Expand All @@ -176,6 +221,8 @@ def wrapper(*test_args, **test_kwargs):
def decorator(func):
@functools.wraps(func)
def wrapper(*test_args, **test_kwargs):
if disable_tracking:
return func(*test_args, **test_kwargs)
_run_test(func, *test_args, **test_kwargs, langtest_extra=langtest_extra)

return wrapper
Expand All @@ -188,7 +235,7 @@ def wrapper(*test_args, **test_kwargs):

def _get_experiment_name() -> str:
# TODO Make more easily configurable
prefix = ls_utils.get_tracer_project(False) or "TestSuite"
prefix = ls_utils.get_tracer_project(False) or "TestSuiteResult"
name = f"{prefix}:{uuid.uuid4().hex[:8]}"
return name

Expand All @@ -199,13 +246,13 @@ def _get_test_suite_name() -> str:
if test_suite_name:
return test_suite_name
if __package__:
return __package__
return __package__ + " Test Suite"
git_info = ls_env.get_git_info()
if git_info:
if git_info["remote_url"]:
repo_name = git_info["remote_url"].split("/")[-1].split(".")[0]
if repo_name:
return repo_name
return repo_name + " Test Suite"
raise ValueError("Please set the LANGCHAIN_TEST_SUITE environment variable.")


Expand All @@ -221,16 +268,19 @@ def _get_test_suite(client: ls_client.Client) -> ls_schemas.Dataset:
def _start_experiment(
client: ls_client.Client,
test_suite: ls_schemas.Dataset,
) -> ls_schemas.TracerSessionResult:
) -> ls_schemas.TracerSession:
experiment_name = _get_experiment_name()
return client.create_project(experiment_name, reference_dataset_id=test_suite.id)


def _get_id(func: Callable, inputs: dict) -> uuid.UUID:
try:
file_path = str(Path(inspect.getfile(func)).relative_to(Path.cwd()))
except ValueError:
# Fall back to the module name when the file is not located under the
# current working directory (relative_to raises ValueError in that case)
file_path = func.__module__
input_json = json.dumps(inputs, sort_keys=True)
identifier = f"{func.__module__}.{func.__name__}_{input_json}"

# Generate a UUID based on the identifier
identifier = f"{file_path}::{func.__name__}{input_json}"
return uuid.uuid5(uuid.NAMESPACE_DNS, identifier)


Expand All @@ -253,7 +303,7 @@ class _LangSmithTestSuite:
def __init__(
self,
client: Optional[ls_client.Client],
experiment: ls_schemas.TracerSessionResult,
experiment: ls_schemas.TracerSession,
dataset: ls_schemas.Dataset,
):
self.client = client or ls_client.Client()
Expand Down Expand Up @@ -338,6 +388,7 @@ class _UTExtra(TypedDict, total=False):
id: Optional[uuid.UUID]
output_keys: Optional[Sequence[str]]
test_suite_name: Optional[str]
cache: Optional[str]


def _ensure_example(
Expand Down Expand Up @@ -367,21 +418,32 @@ def _run_test(func, *test_args, langtest_extra: _UTExtra, **test_kwargs):
)
run_id = uuid.uuid4()

try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")
def _test():
try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")

cache_path = (
Path(langtest_extra["cache"]) / f"{test_suite.id}.yaml"
if langtest_extra["cache"]
else None
)
with ls_utils.with_optional_cache(
cache_path, ignore_hosts=[test_suite.client.api_url]
):
_test()
18 changes: 12 additions & 6 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,9 @@ def info(self) -> ls_schemas.LangSmithInfo:
ls_utils.raise_for_status_with_text(response)
self._info = ls_schemas.LangSmithInfo(**response.json())
except BaseException as e:
logger.warning(f"Failed to get info from {self.api_url}: {repr(e)}")
logger.warning(
f"Failed to get info from {self.api_url}: {repr(e)}",
)
self._info = ls_schemas.LangSmithInfo()
return self._info

Expand Down Expand Up @@ -810,7 +812,11 @@ def _get_paginated_list(
params_["limit"] = params_.get("limit", 100)
while True:
params_["offset"] = offset
response = self.request_with_retries("GET", path, params=params_)
response = self.request_with_retries(
"GET",
path,
params=params_,
)
items = response.json()

if not items:
Expand Down Expand Up @@ -1012,13 +1018,13 @@ def _run_transform(
dict: The transformed run object as a dictionary.
"""
if hasattr(run, "dict") and callable(getattr(run, "dict")):
run_create = run.dict() # type: ignore
run_create: dict = run.dict() # type: ignore
else:
run_create = cast(dict, run)
if "id" not in run_create:
run_create["id"] = uuid.uuid4()
elif isinstance(run["id"], str):
run["id"] = uuid.UUID(run["id"])
elif isinstance(run_create["id"], str):
run_create["id"] = uuid.UUID(run_create["id"])
if "inputs" in run_create and run_create["inputs"] is not None:
run_create["inputs"] = self._hide_run_inputs(run_create["inputs"])
if "outputs" in run_create and run_create["outputs"] is not None:
Expand Down Expand Up @@ -3161,7 +3167,7 @@ def _resolve_run_id(
if isinstance(run, (str, uuid.UUID)):
run_ = self.read_run(run, load_child_runs=load_child_runs)
else:
run_ = run
run_ = cast(ls_schemas.Run, run)
return run_

def _resolve_example_id(
Expand Down
Loading

0 comments on commit de1435c

Please sign in to comment.