Implement "streaming" multipart requests to multipart endpoint #1055

Merged · 8 commits · Oct 1, 2024
192 changes: 189 additions & 3 deletions python/langsmith/client.py
@@ -39,6 +39,7 @@
import warnings
import weakref
from dataclasses import dataclass, field
from inspect import signature
from queue import Empty, PriorityQueue, Queue
from typing import (
    TYPE_CHECKING,
@@ -63,7 +64,9 @@
import orjson
import requests
from requests import adapters as requests_adapters
from requests_toolbelt.multipart import MultipartEncoder # type: ignore[import-untyped]
from typing_extensions import TypeGuard
from urllib3.poolmanager import PoolKey # type: ignore[attr-defined]
from urllib3.util import Retry

import langsmith
@@ -92,6 +95,9 @@ class ZoneInfo: # type: ignore[no-redef]
X_API_KEY = "x-api-key"
WARNED_ATTACHMENTS = False
EMPTY_SEQ: tuple[Dict, ...] = ()
BOUNDARY = uuid.uuid4().hex
MultipartParts = List[Tuple[str, Tuple[None, bytes, str]]]
URLLIB3_SUPPORTS_BLOCKSIZE = "key_blocksize" in signature(PoolKey).parameters


def _parse_token_or_url(
@@ -459,7 +465,9 @@ def __init__(
        super().__init__(pool_connections, pool_maxsize, max_retries, pool_block)

    def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
-        pool_kwargs["blocksize"] = self._blocksize
+        if URLLIB3_SUPPORTS_BLOCKSIZE:
+            # urllib3 before 2.0 doesn't support blocksize
+            pool_kwargs["blocksize"] = self._blocksize
        return super().init_poolmanager(connections, maxsize, block, **pool_kwargs)


@@ -1538,6 +1546,178 @@ def _post_batch_ingest_runs(self, body: bytes, *, _context: str):
                except Exception:
                    logger.warning(f"Failed to batch ingest runs: {repr(e)}")

    def multipart_ingest_runs(
        self,
        create: Optional[
            Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
        ] = None,
        update: Optional[
            Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
        ] = None,
        *,
        pre_sampled: bool = False,
    ):
        """Batch ingest/upsert multiple runs in the Langsmith system.

        Args:
            create (Optional[Sequence[Union[ls_schemas.Run, RunLikeDict]]]):
                A sequence of `Run` objects or equivalent dictionaries representing
                runs to be created / posted.
            update (Optional[Sequence[Union[ls_schemas.Run, RunLikeDict]]]):
                A sequence of `Run` objects or equivalent dictionaries representing
                runs that have already been created and should be updated / patched.
            pre_sampled (bool, optional): Whether the runs have already been subject
                to sampling, and therefore should not be sampled again.
                Defaults to False.

        Returns:
            None: If both `create` and `update` are None.

        Raises:
            LangsmithAPIError: If there is an error in the API request.

        Note:
            - The run objects MUST contain the dotted_order and trace_id fields
                to be accepted by the API.
        """
        if not create and not update:
            return
        # transform and convert to dicts
        all_attachments: Dict[str, ls_schemas.Attachments] = {}
        create_dicts = [
            self._run_transform(run, attachments_collector=all_attachments)
            for run in create or EMPTY_SEQ
        ]
        update_dicts = [
            self._run_transform(run, update=True, attachments_collector=all_attachments)
            for run in update or EMPTY_SEQ
        ]
        # require trace_id and dotted_order
        if create_dicts:
            for run in create_dicts:
                if not run.get("trace_id") or not run.get("dotted_order"):
                    raise ls_utils.LangSmithUserError(
                        "Batch ingest requires trace_id and dotted_order to be set."
                    )
            else:
                del run
        if update_dicts:
            for run in update_dicts:
                if not run.get("trace_id") or not run.get("dotted_order"):
                    raise ls_utils.LangSmithUserError(
                        "Batch ingest requires trace_id and dotted_order to be set."
                    )
            else:
                del run
        # combine post and patch dicts where possible
        if update_dicts and create_dicts:
            create_by_id = {run["id"]: run for run in create_dicts}
            standalone_updates: list[dict] = []
            for run in update_dicts:
                if run["id"] in create_by_id:
                    for k, v in run.items():
                        if v is not None:
                            create_by_id[run["id"]][k] = v
                else:
                    standalone_updates.append(run)
            else:
                del run
            update_dicts = standalone_updates
        # filter out runs that are not sampled
        if not pre_sampled:
            create_dicts = self._filter_for_sampling(create_dicts)
            update_dicts = self._filter_for_sampling(update_dicts, patch=True)
        if not create_dicts and not update_dicts:
            return
        # insert runtime environment
        self._insert_runtime_env(create_dicts)
        self._insert_runtime_env(update_dicts)
        # check size limit
        size_limit_bytes = (self.info.batch_ingest_config or {}).get(
            "size_limit_bytes"
        ) or _SIZE_LIMIT_BYTES
        # send the runs in multipart requests
        acc_size = 0
        acc_context: List[str] = []
        acc_parts: MultipartParts = []
        for event, payloads in (("post", create_dicts), ("patch", update_dicts)):
            for payload in payloads:
                parts: MultipartParts = []
                # collect fields to be sent as separate parts
                fields = [
                    ("inputs", payload.pop("inputs", None)),
                    ("outputs", payload.pop("outputs", None)),
                    ("serialized", payload.pop("serialized", None)),
                    ("events", payload.pop("events", None)),
                ]
                # encode the main run payload
                parts.append(
                    (
                        f"{event}.{payload['id']}",
                        (None, _dumps_json(payload), "application/json"),
                    )
                )
                # encode the fields we collected
                for key, value in fields:
                    if value is None:
                        continue
                    parts.append(
                        (
                            f"{event}.{payload['id']}.{key}",
                            (None, _dumps_json(value), "application/json"),
                        ),
                    )
                # encode the attachments
                if attachments := all_attachments.pop(payload["id"], None):
                    for n, (ct, ba) in attachments.items():
                        parts.append(
                            (f"attachment.{payload['id']}.{n}", (None, ba, ct))
                        )
                # calculate the size of the parts
                size = sum(len(p[1][1]) for p in parts)
                # compute context
                context = f"trace={payload.get('trace_id')},id={payload.get('id')}"
                # if next size would exceed limit, send the current parts
                if acc_size + size > size_limit_bytes:
                    self._send_multipart_req(acc_parts, _context="; ".join(acc_context))
                    acc_parts.clear()
                    acc_context.clear()
                    acc_size = 0
                # accumulate the parts
                acc_size += size
                acc_parts.extend(parts)
                acc_context.append(context)
        # send the remaining parts
        if acc_parts:
            self._send_multipart_req(acc_parts, _context="; ".join(acc_context))

    def _send_multipart_req(self, parts: MultipartParts, *, _context: str):
        for api_url, api_key in self._write_api_urls.items():
            try:
                encoder = MultipartEncoder(parts, boundary=BOUNDARY)
                self.request_with_retries(
                    "POST",
                    f"{api_url}/runs/multipart",
                    request_kwargs={
                        "data": encoder,
                        "headers": {
                            **self._headers,
                            X_API_KEY: api_key,
                            "Content-Type": encoder.content_type,
                        },
                    },
                    to_ignore=(ls_utils.LangSmithConflictError,),
                    stop_after_attempt=3,
                    _context=_context,
                )
            except Exception as e:
                try:
                    exc_desc_lines = traceback.format_exception_only(type(e), e)
                    exc_desc = "".join(exc_desc_lines).rstrip()
                    logger.warning(f"Failed to multipart ingest runs: {exc_desc}")
                except Exception:
                    logger.warning(f"Failed to multipart ingest runs: {repr(e)}")

    def update_run(
        self,
        run_id: ID_TYPE,
@@ -5593,7 +5773,10 @@ def _tracing_thread_handle_batch(
    create = [it.item for it in batch if it.action == "create"]
    update = [it.item for it in batch if it.action == "update"]
    try:
-        client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
+        if use_multipart:
+            client.multipart_ingest_runs(create=create, update=update, pre_sampled=True)
+        else:
+            client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
    except Exception:
        logger.error("Error in tracing queue", exc_info=True)
        # exceptions are logged elsewhere, but we need to make sure the
@@ -5642,7 +5825,10 @@ def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
    size_limit: int = batch_ingest_config["size_limit"]
    scale_up_nthreads_limit: int = batch_ingest_config["scale_up_nthreads_limit"]
    scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"]
-    use_multipart: bool = batch_ingest_config.get("use_multipart_endpoint", False)
+    if multipart_override := os.getenv("LANGSMITH_FF_MULTIPART"):
+        use_multipart = multipart_override.lower() in ["1", "true"]
+    else:
+        use_multipart = batch_ingest_config.get("use_multipart_endpoint", False)

    sub_threads: List[threading.Thread] = []
    # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
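
Taken together, the hunk above means the multipart path can be toggled per process from the environment, independent of what the server-advertised batch_ingest_config says. A minimal sketch (only the flag name and its accepted values come from this diff; the rest is illustrative):

import os

from langsmith import Client

# Force the multipart ingestion path for this process. Per the hunk above, "1" or "true"
# (case-insensitive) turns it on; any other non-empty value turns it off, overriding
# batch_ingest_config["use_multipart_endpoint"] advertised by the server.
os.environ["LANGSMITH_FF_MULTIPART"] = "true"

client = Client()  # the background tracing control thread reads the flag when it starts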
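For reviewers who want to exercise the new client method directly rather than through the background tracing queue, a rough sketch of a direct call. Only the method signature, the trace_id/dotted_order requirement, and the part layout (post.<id>, post.<id>.inputs, attachment.<id>.<name>) come from this diff; the field values and the dotted_order format below are made up:

import uuid
from datetime import datetime, timezone

from langsmith import Client

client = Client()  # assumes LANGSMITH_API_KEY (and optionally LANGSMITH_ENDPOINT) are set

run_id = str(uuid.uuid4())
start = datetime.now(timezone.utc)
client.multipart_ingest_runs(
    create=[
        {
            "id": run_id,
            "trace_id": run_id,  # required, enforced by the new method
            # illustrative dotted_order; reuse whatever format you already pass to batch_ingest_runs
            "dotted_order": f"{start.strftime('%Y%m%dT%H%M%S%fZ')}{run_id}",
            "name": "demo-run",
            "run_type": "chain",
            "start_time": start.isoformat(),
            "inputs": {"question": "hello?"},  # becomes a separate post.<run_id>.inputs JSON part
        }
    ],
)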
16 changes: 15 additions & 1 deletion python/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions python/pyproject.toml
@@ -60,6 +60,7 @@ pytest-rerunfailures = "^14.0"
pytest-socket = "^0.7.0"
pyperf = "^2.7.0"
py-spy = "^0.3.14"
multipart = "^1.0.0"

[tool.poetry.group.lint.dependencies]
openai = "^1.10"
12 changes: 10 additions & 2 deletions python/tests/integration_tests/test_client.py
@@ -615,7 +615,10 @@ def test_create_chat_example(
    langchain_client.delete_dataset(dataset_id=dataset.id)


-def test_batch_ingest_runs(langchain_client: Client) -> None:
+@pytest.mark.parametrize("use_multipart_endpoint", [True, False])
+def test_batch_ingest_runs(
+    langchain_client: Client, use_multipart_endpoint: bool
+) -> None:
    _session = "__test_batch_ingest_runs"
    trace_id = uuid4()
    trace_id_2 = uuid4()
@@ -669,7 +672,12 @@ def test_batch_ingest_runs(langchain_client: Client) -> None:
            "outputs": {"output1": 4, "output2": 5},
        },
    ]
-    langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update)
+    if use_multipart_endpoint:
+        langchain_client.multipart_ingest_runs(
+            create=runs_to_create, update=runs_to_update
+        )
+    else:
+        langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update)
    runs = []
    wait = 4
    for _ in range(15):
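To exercise just this test against a live instance (integration-test credentials assumed to be configured), something like the following from the python/ directory runs both parametrizations:

import pytest

# Selects test_batch_ingest_runs, which now runs once with use_multipart_endpoint=True and once with False.
pytest.main(["tests/integration_tests/test_client.py", "-k", "test_batch_ingest_runs"])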