Implement "streaming" multipart requests to multipart endpoint (#1055)
- Use the streaming multipart encoder from requests_toolbelt (see the sketch below)
- Currently dump each part to JSON before sending the request, as that's
the only way to enforce the payload size limit
- When we lift the payload size limit we should implement true streaming
encoding, where each part is encoded only immediately before being sent
over the connection, and use Transfer-Encoding: chunked
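
A minimal sketch of the streaming encoder usage described above, assuming requests_toolbelt is available; the endpoint URL and part name are placeholders rather than the client's real ones:

import uuid

import requests
from requests_toolbelt.multipart import MultipartEncoder

# Each part is (field name, (filename, body bytes, content type)); bodies are
# pre-serialized to JSON so their size can be checked before the request is built.
parts = [
    ("post.example-run-id", (None, b'{"name": "example run"}', "application/json")),
]
encoder = MultipartEncoder(parts, boundary=uuid.uuid4().hex)

# requests reads the encoder as a file-like stream, so the multipart body is
# never materialized as one large bytes object in memory.
requests.post(
    "https://example.com/runs/multipart",  # placeholder URL
    data=encoder,
    headers={"Content-Type": encoder.content_type},
)
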
nfcampos authored Oct 1, 2024
2 parents a20c94c + 08ec720 commit 1f1b76a
Showing 5 changed files with 384 additions and 64 deletions.
192 changes: 189 additions & 3 deletions python/langsmith/client.py
@@ -39,6 +39,7 @@
import warnings
import weakref
from dataclasses import dataclass, field
from inspect import signature
from queue import Empty, PriorityQueue, Queue
from typing import (
TYPE_CHECKING,
@@ -63,7 +64,9 @@
import orjson
import requests
from requests import adapters as requests_adapters
from requests_toolbelt.multipart import MultipartEncoder # type: ignore[import-untyped]
from typing_extensions import TypeGuard
from urllib3.poolmanager import PoolKey # type: ignore[attr-defined]
from urllib3.util import Retry

import langsmith
@@ -92,6 +95,9 @@ class ZoneInfo: # type: ignore[no-redef]
X_API_KEY = "x-api-key"
WARNED_ATTACHMENTS = False
EMPTY_SEQ: tuple[Dict, ...] = ()
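# A fixed boundary generated once per process and reused for every multipart
# request, plus the tuple shape requests_toolbelt expects for each part:
# (field name, (filename, body bytes, content type)).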
BOUNDARY = uuid.uuid4().hex
MultipartParts = List[Tuple[str, Tuple[None, bytes, str]]]
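# Feature-detect blocksize support: urllib3 >= 2.0 adds a blocksize entry to its
# connection PoolKey, while older versions reject the kwarg (see init_poolmanager below).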
URLLIB3_SUPPORTS_BLOCKSIZE = "key_blocksize" in signature(PoolKey).parameters


def _parse_token_or_url(
@@ -459,7 +465,9 @@ def __init__(
super().__init__(pool_connections, pool_maxsize, max_retries, pool_block)

def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
pool_kwargs["blocksize"] = self._blocksize
if URLLIB3_SUPPORTS_BLOCKSIZE:
# urllib3 before 2.0 doesn't support blocksize
pool_kwargs["blocksize"] = self._blocksize
return super().init_poolmanager(connections, maxsize, block, **pool_kwargs)


@@ -1538,6 +1546,178 @@ def _post_batch_ingest_runs(self, body: bytes, *, _context: str):
except Exception:
logger.warning(f"Failed to batch ingest runs: {repr(e)}")

def multipart_ingest_runs(
self,
create: Optional[
Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
] = None,
update: Optional[
Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
] = None,
*,
pre_sampled: bool = False,
):
"""Batch ingest/upsert multiple runs in the Langsmith system.
Args:
create (Optional[Sequence[Union[ls_schemas.Run, RunLikeDict]]]):
A sequence of `Run` objects or equivalent dictionaries representing
runs to be created / posted.
update (Optional[Sequence[Union[ls_schemas.Run, RunLikeDict]]]):
A sequence of `Run` objects or equivalent dictionaries representing
runs that have already been created and should be updated / patched.
pre_sampled (bool, optional): Whether the runs have already been subject
to sampling, and therefore should not be sampled again.
Defaults to False.
Returns:
None: If both `create` and `update` are None.
Raises:
LangsmithAPIError: If there is an error in the API request.
Note:
- The run objects MUST contain the dotted_order and trace_id fields
to be accepted by the API.
"""
if not create and not update:
return
# transform and convert to dicts
all_attachments: Dict[str, ls_schemas.Attachments] = {}
create_dicts = [
self._run_transform(run, attachments_collector=all_attachments)
for run in create or EMPTY_SEQ
]
update_dicts = [
self._run_transform(run, update=True, attachments_collector=all_attachments)
for run in update or EMPTY_SEQ
]
# require trace_id and dotted_order
if create_dicts:
for run in create_dicts:
if not run.get("trace_id") or not run.get("dotted_order"):
raise ls_utils.LangSmithUserError(
"Batch ingest requires trace_id and dotted_order to be set."
)
else:
del run
if update_dicts:
for run in update_dicts:
if not run.get("trace_id") or not run.get("dotted_order"):
raise ls_utils.LangSmithUserError(
"Batch ingest requires trace_id and dotted_order to be set."
)
else:
del run
# combine post and patch dicts where possible
if update_dicts and create_dicts:
create_by_id = {run["id"]: run for run in create_dicts}
standalone_updates: list[dict] = []
for run in update_dicts:
if run["id"] in create_by_id:
for k, v in run.items():
if v is not None:
create_by_id[run["id"]][k] = v
else:
standalone_updates.append(run)
else:
del run
update_dicts = standalone_updates
# filter out runs that are not sampled
if not pre_sampled:
create_dicts = self._filter_for_sampling(create_dicts)
update_dicts = self._filter_for_sampling(update_dicts, patch=True)
if not create_dicts and not update_dicts:
return
# insert runtime environment
self._insert_runtime_env(create_dicts)
self._insert_runtime_env(update_dicts)
# check size limit
size_limit_bytes = (self.info.batch_ingest_config or {}).get(
"size_limit_bytes"
) or _SIZE_LIMIT_BYTES
# send the runs in multipart requests
acc_size = 0
acc_context: List[str] = []
acc_parts: MultipartParts = []
for event, payloads in (("post", create_dicts), ("patch", update_dicts)):
for payload in payloads:
parts: MultipartParts = []
# collect fields to be sent as separate parts
fields = [
("inputs", payload.pop("inputs", None)),
("outputs", payload.pop("outputs", None)),
("serialized", payload.pop("serialized", None)),
("events", payload.pop("events", None)),
]
# encode the main run payload
parts.append(
(
f"{event}.{payload['id']}",
(None, _dumps_json(payload), "application/json"),
)
)
# encode the fields we collected
for key, value in fields:
if value is None:
continue
parts.append(
(
f"{event}.{payload['id']}.{key}",
(None, _dumps_json(value), "application/json"),
),
)
# encode the attachments
if attachments := all_attachments.pop(payload["id"], None):
for n, (ct, ba) in attachments.items():
parts.append(
(f"attachment.{payload['id']}.{n}", (None, ba, ct))
)
# calculate the size of the parts
size = sum(len(p[1][1]) for p in parts)
# compute context
context = f"trace={payload.get('trace_id')},id={payload.get('id')}"
# if next size would exceed limit, send the current parts
if acc_size + size > size_limit_bytes:
self._send_multipart_req(acc_parts, _context="; ".join(acc_context))
acc_parts.clear()
acc_context.clear()
acc_size = 0
# accumulate the parts
acc_size += size
acc_parts.extend(parts)
acc_context.append(context)
# send the remaining parts
if acc_parts:
self._send_multipart_req(acc_parts, _context="; ".join(acc_context))

def _send_multipart_req(self, parts: MultipartParts, *, _context: str):
for api_url, api_key in self._write_api_urls.items():
try:
encoder = MultipartEncoder(parts, boundary=BOUNDARY)
self.request_with_retries(
"POST",
f"{api_url}/runs/multipart",
request_kwargs={
"data": encoder,
"headers": {
**self._headers,
X_API_KEY: api_key,
"Content-Type": encoder.content_type,
},
},
to_ignore=(ls_utils.LangSmithConflictError,),
stop_after_attempt=3,
_context=_context,
)
except Exception as e:
try:
exc_desc_lines = traceback.format_exception_only(type(e), e)
exc_desc = "".join(exc_desc_lines).rstrip()
logger.warning(f"Failed to multipart ingest runs: {exc_desc}")
except Exception:
logger.warning(f"Failed to multipart ingest runs: {repr(e)}")

def update_run(
self,
run_id: ID_TYPE,
@@ -5593,7 +5773,10 @@ def _tracing_thread_handle_batch(
create = [it.item for it in batch if it.action == "create"]
update = [it.item for it in batch if it.action == "update"]
try:
client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
if use_multipart:
client.multipart_ingest_runs(create=create, update=update, pre_sampled=True)
else:
client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
except Exception:
logger.error("Error in tracing queue", exc_info=True)
# exceptions are logged elsewhere, but we need to make sure the
@@ -5642,7 +5825,10 @@ def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
size_limit: int = batch_ingest_config["size_limit"]
scale_up_nthreads_limit: int = batch_ingest_config["scale_up_nthreads_limit"]
scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"]
use_multipart: bool = batch_ingest_config.get("use_multipart_endpoint", False)
if multipart_override := os.getenv("LANGSMITH_FF_MULTIPART"):
use_multipart = multipart_override.lower() in ["1", "true"]
else:
use_multipart = batch_ingest_config.get("use_multipart_endpoint", False)

sub_threads: List[threading.Thread] = []
# 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
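
For orientation, a hedged usage sketch (not part of this commit): how multipart_ingest_runs might be called with plain run dicts, and the part names the method builds. The field values and the exact dotted_order format are illustrative assumptions.

from datetime import datetime, timezone
from uuid import uuid4

from langsmith import Client

client = Client()  # assumes LANGSMITH_API_KEY and endpoint are configured
run_id = uuid4()
start = datetime.now(timezone.utc)
client.multipart_ingest_runs(
    create=[
        {
            "id": run_id,
            "trace_id": run_id,  # a root run's trace id is its own id
            # dotted_order format assumed here: <start time><run id>
            "dotted_order": f"{start.strftime('%Y%m%dT%H%M%S%fZ')}{run_id}",
            "name": "example-run",
            "run_type": "chain",
            "start_time": start,
            "inputs": {"question": "hi"},
        }
    ]
)
# Per the method above, the request body then contains parts named roughly:
#   post.<run id>            - the run payload with the large fields popped out
#   post.<run id>.inputs     - the inputs as a separate application/json part
#   attachment.<run id>.<n>  - one part per attachment, with its own content type
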
16 changes: 15 additions & 1 deletion python/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions python/pyproject.toml
@@ -60,6 +60,7 @@ pytest-rerunfailures = "^14.0"
pytest-socket = "^0.7.0"
pyperf = "^2.7.0"
py-spy = "^0.3.14"
multipart = "^1.0.0"

[tool.poetry.group.lint.dependencies]
openai = "^1.10"
12 changes: 10 additions & 2 deletions python/tests/integration_tests/test_client.py
@@ -615,7 +615,10 @@ def test_create_chat_example(
langchain_client.delete_dataset(dataset_id=dataset.id)


def test_batch_ingest_runs(langchain_client: Client) -> None:
@pytest.mark.parametrize("use_multipart_endpoint", [True, False])
def test_batch_ingest_runs(
langchain_client: Client, use_multipart_endpoint: bool
) -> None:
_session = "__test_batch_ingest_runs"
trace_id = uuid4()
trace_id_2 = uuid4()
@@ -669,7 +672,12 @@ def test_batch_ingest_runs(langchain_client: Client) -> None:
"outputs": {"output1": 4, "output2": 5},
},
]
langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update)
if use_multipart_endpoint:
langchain_client.multipart_ingest_runs(
create=runs_to_create, update=runs_to_update
)
else:
langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update)
runs = []
wait = 4
for _ in range(15):
(Diff for the remaining changed file is not shown.)