Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Sep 22, 2023
1 parent 4de1f2e commit 34dd1be
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 13 deletions.
42 changes: 40 additions & 2 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
DefaultDict,
Dict,
Iterable,
Expand Down Expand Up @@ -129,7 +130,44 @@ def close_client(client: httpx.Client) -> None:

def close_async_client(client: httpx.AsyncClient) -> None:
logger.debug("Closing Client._aclient")
asyncio.run(client.aclose())
coro = client.aclose()
try:
# Raises RuntimeError if there is no current event loop.
asyncio.get_running_loop()
loop_running = True
except RuntimeError:
loop_running = False

if loop_running:
# If we try to submit this coroutine to the running loop
# we end up in a deadlock, as we'd have gotten here from a
# running coroutine, which we cannot interrupt to run this one.
# The solution is to create a new loop in a new thread.
with concurrent.futures.ThreadPoolExecutor(1) as executor:
executor.submit(_run_coros, coro).result()
else:
_run_coros(coro)


def _run_coros(coro: Coroutine) -> None:
if hasattr(asyncio, "Runner"):
# Python 3.11+
# Run the coroutines in a new event loop, taking care to
# - install signal handlers
# - run pending tasks scheduled by `coros`
# - close asyncgens and executors
# - close the loop
with asyncio.Runner() as runner:
# Run the coroutine, get the result
runner.run(coro)

# Run pending tasks scheduled by coros until they are all done
while pending := asyncio.all_tasks(runner.get_loop()):
runner.run(asyncio.wait(pending))
else:
# Before Python 3.11 we need to run each coroutine in a new event loop
# as the Runner api is not available.
asyncio.run(coro)


def _validate_api_key_if_hosted(api_url: str, api_key: Optional[str]) -> None:
Expand Down Expand Up @@ -1247,7 +1285,7 @@ def _prepare_create_project(
project_extra: Optional[dict] = None,
upsert: bool = False,
reference_dataset_id: Optional[ID_TYPE] = None,
) -> ls_schemas.TracerSession:
) -> dict:
"""Create a project on the LangSmith API.
Parameters
Expand Down
178 changes: 178 additions & 0 deletions python/tests/integration_tests/test_client_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""LangSmith langchain_client Integration Tests."""
import os
from datetime import datetime
from uuid import uuid4

import pytest
import requests
from freezegun import freeze_time

from langsmith.client import Client
from langsmith.utils import LangSmithConnectionError, LangSmithError


@pytest.fixture
def langchain_client(monkeypatch: pytest.MonkeyPatch) -> Client:
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
return Client()


@pytest.mark.asyncio
async def test_projects(
langchain_client: Client, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test projects."""
project_names = set([project.name for project in langchain_client.list_projects()])
new_project = "__Test Project"
if new_project in project_names:
langchain_client.delete_project(project_name=new_project)
project_names = set(
[project.name for project in langchain_client.list_projects()]
)
assert new_project not in project_names

monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
await langchain_client.acreate_project(
project_name=new_project,
project_extra={"evaluator": "THE EVALUATOR"},
)
project = await langchain_client.aread_project(project_name=new_project)
assert project.name == new_project
project_names = set([sess.name for sess in langchain_client.list_projects()])
assert new_project in project_names
runs = [run async for run in langchain_client.alist_runs(project_name=new_project)]
project_id_runs = [
run async for run in langchain_client.alist_runs(project_id=project.id)
]
assert len(runs) == len(project_id_runs) == 0 # TODO: Add create_run method
langchain_client.delete_project(project_name=new_project)

with pytest.raises(LangSmithError):
await langchain_client.aread_project(project_name=new_project)
assert new_project not in set(
[sess.name for sess in langchain_client.list_projects()]
)
with pytest.raises(LangSmithError):
langchain_client.delete_project(project_name=new_project)


@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_persist_update_run(
monkeypatch: pytest.MonkeyPatch, langchain_client: Client
) -> None:
"""Test the persist and update methods work as expected."""
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
project_name = "__test_persist_update_run"
if project_name in [sess.name for sess in langchain_client.list_projects()]:
langchain_client.delete_project(project_name=project_name)
start_time = datetime.now()
run: dict = dict(
id=uuid4(),
name="test_run",
run_type="llm",
inputs={"text": "hello world"},
project_name=project_name,
api_url=os.getenv("LANGCHAIN_ENDPOINT"),
execution_order=1,
start_time=start_time,
extra={"extra": "extra"},
)
await langchain_client.acreate_run(**run)
run["outputs"] = {"output": ["Hi"]}
run["extra"]["foo"] = "bar"
await langchain_client.aupdate_run(run["id"], **run)
stored_run = await langchain_client.aread_run(run["id"])
assert stored_run.id == run["id"]
assert stored_run.outputs == run["outputs"]
assert stored_run.start_time == run["start_time"]
langchain_client.delete_project(project_name=project_name)


@pytest.mark.asyncio
@pytest.mark.parametrize("uri", ["http://localhost:1981", "http://api.langchain.minus"])
async def test_error_surfaced_invalid_uri(
monkeypatch: pytest.MonkeyPatch, uri: str
) -> None:
monkeypatch.setenv("LANGCHAIN_ENDPOINT", uri)
monkeypatch.setenv("LANGCHAIN_API_KEY", "test")
client = Client()
# expect connect error
with pytest.raises(LangSmithConnectionError):
await client.acreate_run(
"My Run", inputs={"text": "hello world"}, run_type="llm"
)


@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_share_unshare_run(
monkeypatch: pytest.MonkeyPatch, langchain_client: Client
) -> None:
"""Test persisting runs and adding feedback."""
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
run_id = uuid4()
await langchain_client.acreate_run(
name="Test run",
inputs={"input": "hello world"},
run_type="chain",
id=run_id,
)
shared_url = await langchain_client.ashare_run(run_id)
response = requests.get(shared_url)
assert response.status_code == 200
assert await langchain_client.aread_run_shared_link(run_id) == shared_url
await langchain_client.aunshare_run(run_id)


@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_create_run_with_masked_inputs_outputs(
langchain_client: Client, monkeypatch: pytest.MonkeyPatch
) -> None:
project_name = "__test_create_run_with_masked_inputs_outputs"
monkeypatch.setenv("LANGCHAIN_HIDE_INPUTS", "true")
monkeypatch.setenv("LANGCHAIN_HIDE_OUTPUTS", "true")
for project in langchain_client.list_projects():
if project.name == project_name:
langchain_client.delete_project(project_name=project_name)

run_id = "8bac165f-470e-4bf8-baa0-15f2de4cc706"
await langchain_client.acreate_run(
id=run_id,
project_name=project_name,
name="test_run",
run_type="llm",
inputs={"prompt": "hello world"},
outputs={"generation": "hi there"},
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
hide_inputs=True,
hide_outputs=True,
)

run_id2 = "8bac165f-490e-4bf8-baa0-15f2de4cc707"
await langchain_client.acreate_run(
id=run_id2,
project_name=project_name,
name="test_run_2",
run_type="llm",
inputs={"messages": "hello world 2"},
start_time=datetime.utcnow(),
hide_inputs=True,
)

await langchain_client.aupdate_run(
run_id2,
outputs={"generation": "hi there 2"},
end_time=datetime.utcnow(),
hide_outputs=True,
)

run1 = await langchain_client.aread_run(run_id)
assert run1.inputs == {}
assert run1.outputs == {}

run2 = await langchain_client.aread_run(run_id2)
assert run2.inputs == {}
assert run2.outputs == {}
39 changes: 28 additions & 11 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest import mock
from unittest.mock import patch

import httpx
import pytest

from langsmith.client import (
Expand All @@ -15,6 +16,7 @@
_get_api_url,
_is_langchain_hosted,
_is_localhost,
close_async_client,
)
from langsmith.schemas import Example
from langsmith.utils import LangSmithUserError
Expand Down Expand Up @@ -53,8 +55,8 @@ def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
assert client_no_key._headers == {}


@mock.patch("langsmith.client.requests.Session")
def test_upload_csv(mock_session_cls: mock.Mock) -> None:
@mock.patch("langsmith.client.httpx.Client")
def test_upload_csv(mock_client_cls: mock.Mock) -> None:
dataset_id = str(uuid.uuid4())
example_1 = Example(
id=str(uuid.uuid4()),
Expand All @@ -79,9 +81,9 @@ def test_upload_csv(mock_session_cls: mock.Mock) -> None:
"created_at": _CREATED_AT,
"examples": [example_1, example_2],
}
mock_session = mock.Mock()
mock_session.post.return_value = mock_response
mock_session_cls.return_value = mock_session
mock__client = mock.Mock()
mock__client.post.return_value = mock_response
mock_client_cls.return_value = mock__client

client = Client(
api_url="http://localhost:1984",
Expand Down Expand Up @@ -174,9 +176,9 @@ def test_create_run_unicode():
"qux": "나는\u3000밥을\u3000먹었습니다.",
"는\u3000밥": "나는\u3000밥을\u3000먹었습니다.",
}
session = mock.Mock()
session.request = mock.Mock()
with patch.object(client, "session", session):
_client = mock.Mock()
_client.request = mock.Mock()
with patch.object(client, "_client", _client):
id_ = uuid.uuid4()
client.create_run(
"my_run", inputs=inputs, run_type="llm", execution_order=1, id=id_
Expand All @@ -187,7 +189,7 @@ def test_create_run_unicode():
@pytest.mark.parametrize("source_type", ["api", "model"])
def test_create_feedback_string_source_type(source_type: str):
client = Client(api_url="http://localhost:1984", api_key="123")
session = mock.Mock()
_client = mock.Mock()
request_object = mock.Mock()
request_object.json.return_value = {
"id": uuid.uuid4(),
Expand All @@ -196,11 +198,26 @@ def test_create_feedback_string_source_type(source_type: str):
"modified_at": _CREATED_AT,
"run_id": uuid.uuid4(),
}
session.post.return_value = request_object
with patch.object(client, "session", session):
_client.post.return_value = request_object
with patch.object(client, "_client", _client):
id_ = uuid.uuid4()
client.create_feedback(
id_,
key="Foo",
feedback_source_type=source_type,
)


@pytest.mark.asyncio
async def test_close_async_client() -> None:
async with httpx.AsyncClient() as client:
close_async_client(client)
assert client.is_closed


@pytest.mark.asyncio
async def test_close_async_client_in_event_loop() -> None:
async with httpx.AsyncClient() as client:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, close_async_client, client)
assert client.is_closed

0 comments on commit 34dd1be

Please sign in to comment.