Pre-read smallish multipart requests (#1176)
hinthornw authored Nov 6, 2024
1 parent 0a4fb6d commit bb09ced
Showing 8 changed files with 603 additions and 381 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration_tests.yml
@@ -42,7 +42,7 @@ jobs:
- name: Install dependencies
run: |
poetry install --with dev
-          poetry run pip install -U langchain langchain_anthropic langchain_openai rapidfuzz
+          poetry run pip install -U langchain langchain_anthropic langchain_openai rapidfuzz pandas
- name: Run Python integration tests
uses: ./.github/actions/python-integration-tests
with:
14 changes: 10 additions & 4 deletions python/langsmith/client.py
@@ -58,7 +58,9 @@
import orjson
import requests
from requests import adapters as requests_adapters
-from requests_toolbelt.multipart import MultipartEncoder  # type: ignore[import-untyped]
+from requests_toolbelt import (  # type: ignore[import-untyped]
+    multipart as rqtb_multipart,
+)
from typing_extensions import TypeGuard
from urllib3.poolmanager import PoolKey # type: ignore[attr-defined]
from urllib3.util import Retry
@@ -1561,12 +1563,16 @@ def _send_multipart_req(self, acc: MultipartPartsAndContext, *, attempts: int =
for api_url, api_key in self._write_api_urls.items():
for idx in range(1, attempts + 1):
try:
-                    encoder = MultipartEncoder(parts, boundary=BOUNDARY)
+                    encoder = rqtb_multipart.MultipartEncoder(parts, boundary=BOUNDARY)
+                    if encoder.len <= 20_000_000:  # ~20 MB
+                        data = encoder.to_string()
+                    else:
+                        data = encoder
self.request_with_retries(
"POST",
f"{api_url}/runs/multipart",
request_kwargs={
"data": encoder,
"data": data,
"headers": {
**self._headers,
X_API_KEY: api_key,
@@ -2433,7 +2439,7 @@ def _get_optional_tenant_id(self) -> Optional[uuid.UUID]:
self._tenant_id = tracer_session.tenant_id
return self._tenant_id
except Exception as e:
-            logger.warning(
+            logger.debug(
"Failed to get tenant ID from LangSmith: %s", repr(e), exc_info=True
)
return None
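The `_send_multipart_req` change above is the heart of this commit: payloads that fit under roughly 20 MB are now serialized up front with `to_string()`, so a slow read can no longer stall an open HTTP connection, while larger payloads keep streaming. A minimal sketch of the pattern, assuming illustrative `parts`, boundary, and URL values rather than the client's real ones:

import requests
from requests_toolbelt import multipart as rqtb_multipart

def post_multipart(url: str, parts, boundary: str, headers: dict) -> requests.Response:
    encoder = rqtb_multipart.MultipartEncoder(parts, boundary=boundary)
    if encoder.len <= 20_000_000:  # ~20 MB: pre-read the whole body into bytes
        data = encoder.to_string()  # slow parts are consumed here, before the POST
    else:
        data = encoder  # large payloads still stream chunk by chunk
    return requests.post(
        url,
        data=data,
        headers={**headers, "Content-Type": encoder.content_type},
    )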
885 changes: 518 additions & 367 deletions python/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.139"
version = "0.1.140"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <[email protected]>"]
license = "MIT"
5 changes: 3 additions & 2 deletions python/tests/evaluation/test_evaluation.py
@@ -329,8 +329,9 @@ async def apredict(inputs: dict) -> dict:
num_repetitions=2,
)
assert len(results) == 20
-    df = results.to_pandas()
-    assert len(df) == 20
+    if _has_pandas():
+        df = results.to_pandas()
+        assert len(df) == 20
examples = client.list_examples(dataset_name=dataset_name, as_of="test_version")
all_results = [r async for r in results]
all_examples = []
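The guard above keeps the evaluation test passing when pandas is not installed (it is only pulled in for integration runs by the workflow change at the top). `_has_pandas()` itself is not shown in this diff; presumably it is a simple import probe along these lines (a sketch, not the library's actual implementation):

import importlib.util

def _has_pandas() -> bool:
    # True when pandas is importable in the current environment
    return importlib.util.find_spec("pandas") is not None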
58 changes: 58 additions & 0 deletions python/tests/integration_tests/test_client.py
@@ -8,13 +8,16 @@
import string
import sys
import time
+import uuid
from datetime import timedelta
from typing import Any, Callable, Dict
from unittest import mock
from uuid import uuid4

import pytest
from freezegun import freeze_time
from pydantic import BaseModel
+from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor

from langsmith.client import ID_TYPE, Client
from langsmith.schemas import DataType
@@ -24,6 +27,8 @@
get_env_var,
)

+logger = logging.getLogger(__name__)


def wait_for(
condition: Callable[[], bool], max_sleep_time: int = 120, sleep_time: int = 3
@@ -960,3 +965,56 @@ def test_runs_stats():
# We always have stuff in the "default" project...
stats = langchain_client.get_run_stats(project_names=["default"], run_type="llm")
assert stats


def test_slow_run_read_multipart(
langchain_client: Client, caplog: pytest.LogCaptureFixture
):
myobj = {f"key_{i}": f"val_{i}" for i in range(500)}
id_ = str(uuid.uuid4())
current_time = datetime.datetime.now(datetime.timezone.utc).strftime(
"%Y%m%dT%H%M%S%fZ"
)
run_to_create = {
"id": id_,
"session_name": "default",
"name": "trace a root",
"run_type": "chain",
"dotted_order": f"{current_time}{id_}",
"trace_id": id_,
"inputs": myobj,
}

class CB:
def __init__(self):
self.called = 0
self.start_time = None

def __call__(self, monitor: MultipartEncoderMonitor):
self.called += 1
if not self.start_time:
self.start_time = time.time()
logger.debug(
f"[{self.called}]: {monitor.bytes_read} bytes,"
f" {time.time() - self.start_time:.2f} seconds"
" elapsed",
)
if self.called == 1:
time.sleep(6)

def create_encoder(*args, **kwargs):
encoder = MultipartEncoder(*args, **kwargs)
encoder = MultipartEncoderMonitor(encoder, CB())
return encoder

with caplog.at_level(logging.WARNING, logger="langsmith.client"):
with mock.patch(
"langsmith.client.rqtb_multipart.MultipartEncoder", create_encoder
):
langchain_client.create_run(**run_to_create)
time.sleep(1)
start_time = time.time()
while time.time() - start_time < 8:
myobj["key_1"]

assert not caplog.records
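This test patches the client's encoder with a `MultipartEncoderMonitor` whose callback stalls for six seconds on the first read, then asserts that no warnings reach `langsmith.client`: since the small body is pre-read, the stall happens during serialization rather than on the open socket. In isolation, the monitor pattern looks like this (illustrative field values):

from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor

def on_read(monitor: MultipartEncoderMonitor) -> None:
    # invoked on every read of the wrapped encoder
    print(f"{monitor.bytes_read} of {monitor.len} bytes read")

encoder = MultipartEncoder({"field": "value"})
monitor = MultipartEncoderMonitor(encoder, on_read)
body = monitor.to_string()  # the callback fires as the body is consumed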
13 changes: 9 additions & 4 deletions python/tests/unit_tests/test_client.py
@@ -3,6 +3,7 @@
import asyncio
import dataclasses
import gc
+import io
import itertools
import json
import logging
@@ -27,7 +28,6 @@
from multipart import MultipartParser, MultipartPart, parse_options_header
from pydantic import BaseModel
from requests import HTTPError
-from requests_toolbelt.multipart import MultipartEncoder

import langsmith.env as ls_env
import langsmith.utils as ls_utils
@@ -349,9 +349,9 @@ def test_create_run_mutate(
assert headers["Content-Type"].startswith("multipart/form-data")
# this is a current implementation detail, if we change implementation
# we update this assertion
-    assert isinstance(data, MultipartEncoder)
+    assert isinstance(data, bytes)
boundary = parse_options_header(headers["Content-Type"])[1]["boundary"]
-    parser = MultipartParser(data, boundary)
+    parser = MultipartParser(io.BytesIO(data), boundary)
parts.extend(parser.parts())

assert [p.name for p in parts] == [
@@ -1094,12 +1094,17 @@ def test_batch_ingest_run_splits_large_batches(
assert sum(
[1 for call in mock_session.request.call_args_list if call[0][0] == "POST"]
) in (expected_num_requests, expected_num_requests + 1)

request_bodies = [
op
for call in mock_session.request.call_args_list
for op in (
MultipartParser(
call[1]["data"],
(
io.BytesIO(call[1]["data"])
if isinstance(call[1]["data"], bytes)
else call[1]["data"]
),
parse_options_header(call[1]["headers"]["Content-Type"])[1][
"boundary"
],
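Both unit-test updates follow from the same detail: a pre-read body arrives as raw bytes, and `multipart`'s `MultipartParser` consumes a file-like stream, hence the `io.BytesIO` wrapper. A self-contained illustration with a hypothetical field and boundary:

import io

from multipart import MultipartParser
from requests_toolbelt import MultipartEncoder

encoder = MultipartEncoder({"field": "value"}, boundary="example-boundary")
body = encoder.to_string()  # bytes, as the pre-read path now produces
parser = MultipartParser(io.BytesIO(body), "example-boundary")
assert [p.name for p in parser.parts()] == ["field"]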
5 changes: 3 additions & 2 deletions python/tests/unit_tests/test_run_helpers.py
@@ -1729,10 +1729,11 @@ def my_func(
)
),
)
long_content = b"c" * 20_000_000
with tracing_context(enabled=True):
result = my_func(
42,
ls_schemas.Attachment(mime_type="text/plain", data="content1"),
ls_schemas.Attachment(mime_type="text/plain", data=long_content),
("application/octet-stream", "content2"),
langsmith_extra={"client": mock_client},
)
@@ -1764,7 +1765,7 @@ def my_func(
data for data in datas if data[0] == f"attachment.{trace_id}.att1"
)
assert mime_type1 == "text/plain"
assert content1 == b"content1"
assert content1 == long_content

_, (mime_type2, content2) = next(
data for data in datas if data[0] == f"attachment.{trace_id}.att2"
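The attachment fixture grows from a short string to 20,000,000 bytes, which (with boundary and per-part header overhead) pushes the encoded body past the `encoder.len <= 20_000_000` pre-read cutoff, so this test presumably now covers the streaming branch rather than the pre-read one. A quick sketch of that arithmetic:

from requests_toolbelt import MultipartEncoder

encoder = MultipartEncoder({"att1": ("att1", b"c" * 20_000_000, "text/plain")})
assert encoder.len > 20_000_000  # boundary + part headers push it past the cutoff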
