diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 2596e4d7c..cc50e2b78 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -20,6 +20,7 @@ AsyncGenerator, AsyncIterator, Callable, + Coroutine, DefaultDict, Dict, Iterable, @@ -129,7 +130,44 @@ def close_client(client: httpx.Client) -> None: def close_async_client(client: httpx.AsyncClient) -> None: logger.debug("Closing Client._aclient") - asyncio.run(client.aclose()) + coro = client.aclose() + try: + # Raises RuntimeError if there is no current event loop. + asyncio.get_running_loop() + loop_running = True + except RuntimeError: + loop_running = False + + if loop_running: + # If we try to submit this coroutine to the running loop + # we end up in a deadlock, as we'd have gotten here from a + # running coroutine, which we cannot interrupt to run this one. + # The solution is to create a new loop in a new thread. + with concurrent.futures.ThreadPoolExecutor(1) as executor: + executor.submit(_run_coros, coro).result() + else: + _run_coros(coro) + + +def _run_coros(coro: Coroutine) -> None: + if hasattr(asyncio, "Runner"): + # Python 3.11+ + # Run the coroutines in a new event loop, taking care to + # - install signal handlers + # - run pending tasks scheduled by `coros` + # - close asyncgens and executors + # - close the loop + with asyncio.Runner() as runner: + # Run the coroutine, get the result + runner.run(coro) + + # Run pending tasks scheduled by coros until they are all done + while pending := asyncio.all_tasks(runner.get_loop()): + runner.run(asyncio.wait(pending)) + else: + # Before Python 3.11 we need to run each coroutine in a new event loop + # as the Runner api is not available. + asyncio.run(coro) def _validate_api_key_if_hosted(api_url: str, api_key: Optional[str]) -> None: @@ -1247,7 +1285,7 @@ def _prepare_create_project( project_extra: Optional[dict] = None, upsert: bool = False, reference_dataset_id: Optional[ID_TYPE] = None, - ) -> ls_schemas.TracerSession: + ) -> dict: """Create a project on the LangSmith API. Parameters diff --git a/python/tests/integration_tests/test_client_async.py b/python/tests/integration_tests/test_client_async.py new file mode 100644 index 000000000..89823e23b --- /dev/null +++ b/python/tests/integration_tests/test_client_async.py @@ -0,0 +1,178 @@ +"""LangSmith langchain_client Integration Tests.""" +import os +from datetime import datetime +from uuid import uuid4 + +import pytest +import requests +from freezegun import freeze_time + +from langsmith.client import Client +from langsmith.utils import LangSmithConnectionError, LangSmithError + + +@pytest.fixture +def langchain_client(monkeypatch: pytest.MonkeyPatch) -> Client: + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984") + return Client() + + +@pytest.mark.asyncio +async def test_projects( + langchain_client: Client, monkeypatch: pytest.MonkeyPatch +) -> None: + """Test projects.""" + project_names = set([project.name for project in langchain_client.list_projects()]) + new_project = "__Test Project" + if new_project in project_names: + langchain_client.delete_project(project_name=new_project) + project_names = set( + [project.name for project in langchain_client.list_projects()] + ) + assert new_project not in project_names + + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984") + await langchain_client.acreate_project( + project_name=new_project, + project_extra={"evaluator": "THE EVALUATOR"}, + ) + project = await langchain_client.aread_project(project_name=new_project) + assert project.name == new_project + project_names = set([sess.name for sess in langchain_client.list_projects()]) + assert new_project in project_names + runs = [run async for run in langchain_client.alist_runs(project_name=new_project)] + project_id_runs = [ + run async for run in langchain_client.alist_runs(project_id=project.id) + ] + assert len(runs) == len(project_id_runs) == 0 # TODO: Add create_run method + langchain_client.delete_project(project_name=new_project) + + with pytest.raises(LangSmithError): + await langchain_client.aread_project(project_name=new_project) + assert new_project not in set( + [sess.name for sess in langchain_client.list_projects()] + ) + with pytest.raises(LangSmithError): + langchain_client.delete_project(project_name=new_project) + + +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_persist_update_run( + monkeypatch: pytest.MonkeyPatch, langchain_client: Client +) -> None: + """Test the persist and update methods work as expected.""" + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984") + project_name = "__test_persist_update_run" + if project_name in [sess.name for sess in langchain_client.list_projects()]: + langchain_client.delete_project(project_name=project_name) + start_time = datetime.now() + run: dict = dict( + id=uuid4(), + name="test_run", + run_type="llm", + inputs={"text": "hello world"}, + project_name=project_name, + api_url=os.getenv("LANGCHAIN_ENDPOINT"), + execution_order=1, + start_time=start_time, + extra={"extra": "extra"}, + ) + await langchain_client.acreate_run(**run) + run["outputs"] = {"output": ["Hi"]} + run["extra"]["foo"] = "bar" + await langchain_client.aupdate_run(run["id"], **run) + stored_run = await langchain_client.aread_run(run["id"]) + assert stored_run.id == run["id"] + assert stored_run.outputs == run["outputs"] + assert stored_run.start_time == run["start_time"] + langchain_client.delete_project(project_name=project_name) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("uri", ["http://localhost:1981", "http://api.langchain.minus"]) +async def test_error_surfaced_invalid_uri( + monkeypatch: pytest.MonkeyPatch, uri: str +) -> None: + monkeypatch.setenv("LANGCHAIN_ENDPOINT", uri) + monkeypatch.setenv("LANGCHAIN_API_KEY", "test") + client = Client() + # expect connect error + with pytest.raises(LangSmithConnectionError): + await client.acreate_run( + "My Run", inputs={"text": "hello world"}, run_type="llm" + ) + + +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_share_unshare_run( + monkeypatch: pytest.MonkeyPatch, langchain_client: Client +) -> None: + """Test persisting runs and adding feedback.""" + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984") + run_id = uuid4() + await langchain_client.acreate_run( + name="Test run", + inputs={"input": "hello world"}, + run_type="chain", + id=run_id, + ) + shared_url = await langchain_client.ashare_run(run_id) + response = requests.get(shared_url) + assert response.status_code == 200 + assert await langchain_client.aread_run_shared_link(run_id) == shared_url + await langchain_client.aunshare_run(run_id) + + +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_create_run_with_masked_inputs_outputs( + langchain_client: Client, monkeypatch: pytest.MonkeyPatch +) -> None: + project_name = "__test_create_run_with_masked_inputs_outputs" + monkeypatch.setenv("LANGCHAIN_HIDE_INPUTS", "true") + monkeypatch.setenv("LANGCHAIN_HIDE_OUTPUTS", "true") + for project in langchain_client.list_projects(): + if project.name == project_name: + langchain_client.delete_project(project_name=project_name) + + run_id = "8bac165f-470e-4bf8-baa0-15f2de4cc706" + await langchain_client.acreate_run( + id=run_id, + project_name=project_name, + name="test_run", + run_type="llm", + inputs={"prompt": "hello world"}, + outputs={"generation": "hi there"}, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + hide_inputs=True, + hide_outputs=True, + ) + + run_id2 = "8bac165f-490e-4bf8-baa0-15f2de4cc707" + await langchain_client.acreate_run( + id=run_id2, + project_name=project_name, + name="test_run_2", + run_type="llm", + inputs={"messages": "hello world 2"}, + start_time=datetime.utcnow(), + hide_inputs=True, + ) + + await langchain_client.aupdate_run( + run_id2, + outputs={"generation": "hi there 2"}, + end_time=datetime.utcnow(), + hide_outputs=True, + ) + + run1 = await langchain_client.aread_run(run_id) + assert run1.inputs == {} + assert run1.outputs == {} + + run2 = await langchain_client.aread_run(run_id2) + assert run2.inputs == {} + assert run2.outputs == {} diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index ad0ebe266..4b23ea29a 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -7,6 +7,7 @@ from unittest import mock from unittest.mock import patch +import httpx import pytest from langsmith.client import ( @@ -15,6 +16,7 @@ _get_api_url, _is_langchain_hosted, _is_localhost, + close_async_client, ) from langsmith.schemas import Example from langsmith.utils import LangSmithUserError @@ -53,8 +55,8 @@ def test_headers(monkeypatch: pytest.MonkeyPatch) -> None: assert client_no_key._headers == {} -@mock.patch("langsmith.client.requests.Session") -def test_upload_csv(mock_session_cls: mock.Mock) -> None: +@mock.patch("langsmith.client.httpx.Client") +def test_upload_csv(mock_client_cls: mock.Mock) -> None: dataset_id = str(uuid.uuid4()) example_1 = Example( id=str(uuid.uuid4()), @@ -79,9 +81,9 @@ def test_upload_csv(mock_session_cls: mock.Mock) -> None: "created_at": _CREATED_AT, "examples": [example_1, example_2], } - mock_session = mock.Mock() - mock_session.post.return_value = mock_response - mock_session_cls.return_value = mock_session + mock__client = mock.Mock() + mock__client.post.return_value = mock_response + mock_client_cls.return_value = mock__client client = Client( api_url="http://localhost:1984", @@ -174,9 +176,9 @@ def test_create_run_unicode(): "qux": "나는\u3000밥을\u3000먹었습니다.", "는\u3000밥": "나는\u3000밥을\u3000먹었습니다.", } - session = mock.Mock() - session.request = mock.Mock() - with patch.object(client, "session", session): + _client = mock.Mock() + _client.request = mock.Mock() + with patch.object(client, "_client", _client): id_ = uuid.uuid4() client.create_run( "my_run", inputs=inputs, run_type="llm", execution_order=1, id=id_ @@ -187,7 +189,7 @@ def test_create_run_unicode(): @pytest.mark.parametrize("source_type", ["api", "model"]) def test_create_feedback_string_source_type(source_type: str): client = Client(api_url="http://localhost:1984", api_key="123") - session = mock.Mock() + _client = mock.Mock() request_object = mock.Mock() request_object.json.return_value = { "id": uuid.uuid4(), @@ -196,11 +198,26 @@ def test_create_feedback_string_source_type(source_type: str): "modified_at": _CREATED_AT, "run_id": uuid.uuid4(), } - session.post.return_value = request_object - with patch.object(client, "session", session): + _client.post.return_value = request_object + with patch.object(client, "_client", _client): id_ = uuid.uuid4() client.create_feedback( id_, key="Foo", feedback_source_type=source_type, ) + + +@pytest.mark.asyncio +async def test_close_async_client() -> None: + async with httpx.AsyncClient() as client: + close_async_client(client) + assert client.is_closed + + +@pytest.mark.asyncio +async def test_close_async_client_in_event_loop() -> None: + async with httpx.AsyncClient() as client: + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, close_async_client, client) + assert client.is_closed