From a638687017f7c0adb970b9a0086d651294277a0c Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 19 Nov 2024 06:49:38 -0800 Subject: [PATCH] Revert "orjson optional wip (#1223)" (#1228) This reverts commit 3fce7b72c8d38b138666a7363baa5c78d9e9dd19. --- .../langsmith/_internal/_background_thread.py | 24 ++---- python/langsmith/_internal/_operations.py | 9 +- python/langsmith/_internal/_orjson.py | 84 ------------------- python/langsmith/_internal/_serde.py | 18 ++-- python/langsmith/_testing.py | 4 +- python/langsmith/client.py | 31 +++---- python/poetry.lock | 4 +- python/pyproject.toml | 2 +- python/tests/unit_tests/test_client.py | 8 +- python/tests/unit_tests/test_operations.py | 9 +- 10 files changed, 46 insertions(+), 147 deletions(-) delete mode 100644 python/langsmith/_internal/_orjson.py diff --git a/python/langsmith/_internal/_background_thread.py b/python/langsmith/_internal/_background_thread.py index 844851996..b6aee1f4e 100644 --- a/python/langsmith/_internal/_background_thread.py +++ b/python/langsmith/_internal/_background_thread.py @@ -155,25 +155,13 @@ def tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None: # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached num_known_refs = 3 - def keep_thread_active() -> bool: - # if `client.cleanup()` was called, stop thread - if client and client._manual_cleanup: - return False - if not threading.main_thread().is_alive(): - # main thread is dead. should not be active - return False - - if hasattr(sys, "getrefcount"): - # check if client refs count indicates we're the only remaining - # reference to the client - return sys.getrefcount(client) > num_known_refs + len(sub_threads) - else: - # in PyPy, there is no sys.getrefcount attribute - # for now, keep thread alive - return True - # loop until - while keep_thread_active(): + while ( + # the main thread dies + threading.main_thread().is_alive() + # or we're the only remaining reference to the client + and sys.getrefcount(client) > num_known_refs + len(sub_threads) + ): for thread in sub_threads: if not thread.is_alive(): sub_threads.remove(thread) diff --git a/python/langsmith/_internal/_operations.py b/python/langsmith/_internal/_operations.py index 66decff0f..e1e99d6e2 100644 --- a/python/langsmith/_internal/_operations.py +++ b/python/langsmith/_internal/_operations.py @@ -5,8 +5,9 @@ import uuid from typing import Literal, Optional, Union, cast +import orjson + from langsmith import schemas as ls_schemas -from langsmith._internal import _orjson from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext from langsmith._internal._serde import dumps_json as _dumps_json @@ -168,12 +169,12 @@ def combine_serialized_queue_operations( if op._none is not None and op._none != create_op._none: # TODO optimize this more - this would currently be slowest # for large payloads - create_op_dict = _orjson.loads(create_op._none) + create_op_dict = orjson.loads(create_op._none) op_dict = { - k: v for k, v in _orjson.loads(op._none).items() if v is not None + k: v for k, v in orjson.loads(op._none).items() if v is not None } create_op_dict.update(op_dict) - create_op._none = _orjson.dumps(create_op_dict) + create_op._none = orjson.dumps(create_op_dict) if op.inputs is not None: create_op.inputs = op.inputs diff --git a/python/langsmith/_internal/_orjson.py b/python/langsmith/_internal/_orjson.py deleted file mode 100644 index ecd9e20bc..000000000 --- a/python/langsmith/_internal/_orjson.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Stubs for orjson operations, compatible with PyPy via a json fallback.""" - -try: - from orjson import ( - OPT_NON_STR_KEYS, - OPT_SERIALIZE_DATACLASS, - OPT_SERIALIZE_NUMPY, - OPT_SERIALIZE_UUID, - Fragment, - JSONDecodeError, - dumps, - loads, - ) - -except ImportError: - import dataclasses - import json - import uuid - from typing import Any, Callable, Optional - - OPT_NON_STR_KEYS = 1 - OPT_SERIALIZE_DATACLASS = 2 - OPT_SERIALIZE_NUMPY = 4 - OPT_SERIALIZE_UUID = 8 - - class Fragment: # type: ignore - def __init__(self, payloadb: bytes): - self.payloadb = payloadb - - from json import JSONDecodeError # type: ignore - - def dumps( # type: ignore - obj: Any, - /, - default: Optional[Callable[[Any], Any]] = None, - option: int = 0, - ) -> bytes: # type: ignore - # for now, don't do anything for this case because `json.dumps` - # automatically encodes non-str keys as str by default, unlike orjson - # enable_non_str_keys = bool(option & OPT_NON_STR_KEYS) - - enable_serialize_numpy = bool(option & OPT_SERIALIZE_NUMPY) - enable_serialize_dataclass = bool(option & OPT_SERIALIZE_DATACLASS) - enable_serialize_uuid = bool(option & OPT_SERIALIZE_UUID) - - class CustomEncoder(json.JSONEncoder): # type: ignore - def encode(self, o: Any) -> str: - if isinstance(o, Fragment): - return o.payloadb.decode("utf-8") # type: ignore - return super().encode(o) - - def default(self, o: Any) -> Any: - if enable_serialize_uuid and isinstance(o, uuid.UUID): - return str(o) - if enable_serialize_numpy and hasattr(o, "tolist"): - # even objects like np.uint16(15) have a .tolist() function - return o.tolist() - if ( - enable_serialize_dataclass - and dataclasses.is_dataclass(o) - and not isinstance(o, type) - ): - return dataclasses.asdict(o) - if default is not None: - return default(o) - - return super().default(o) - - return json.dumps(obj, cls=CustomEncoder).encode("utf-8") - - def loads(payload: bytes, /) -> Any: # type: ignore - return json.loads(payload) - - -__all__ = [ - "loads", - "dumps", - "Fragment", - "JSONDecodeError", - "OPT_SERIALIZE_NUMPY", - "OPT_SERIALIZE_DATACLASS", - "OPT_SERIALIZE_UUID", - "OPT_NON_STR_KEYS", -] diff --git a/python/langsmith/_internal/_serde.py b/python/langsmith/_internal/_serde.py index 1bf8865c1..e77f7319d 100644 --- a/python/langsmith/_internal/_serde.py +++ b/python/langsmith/_internal/_serde.py @@ -12,7 +12,7 @@ import uuid from typing import Any -from langsmith._internal import _orjson +import orjson try: from zoneinfo import ZoneInfo # type: ignore[import-not-found] @@ -133,13 +133,13 @@ def dumps_json(obj: Any) -> bytes: The JSON formatted string. """ try: - return _orjson.dumps( + return orjson.dumps( obj, default=_serialize_json, - option=_orjson.OPT_SERIALIZE_NUMPY - | _orjson.OPT_SERIALIZE_DATACLASS - | _orjson.OPT_SERIALIZE_UUID - | _orjson.OPT_NON_STR_KEYS, + option=orjson.OPT_SERIALIZE_NUMPY + | orjson.OPT_SERIALIZE_DATACLASS + | orjson.OPT_SERIALIZE_UUID + | orjson.OPT_NON_STR_KEYS, ) except TypeError as e: # Usually caused by UTF surrogate characters @@ -150,9 +150,9 @@ def dumps_json(obj: Any) -> bytes: ensure_ascii=True, ).encode("utf-8") try: - result = _orjson.dumps( - _orjson.loads(result.decode("utf-8", errors="surrogateescape")) + result = orjson.dumps( + orjson.loads(result.decode("utf-8", errors="surrogateescape")) ) - except _orjson.JSONDecodeError: + except orjson.JSONDecodeError: result = _elide_surrogates(result) return result diff --git a/python/langsmith/_testing.py b/python/langsmith/_testing.py index 9eaa0877f..8dd72fbcb 100644 --- a/python/langsmith/_testing.py +++ b/python/langsmith/_testing.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload +import orjson from typing_extensions import TypedDict from langsmith import client as ls_client @@ -20,7 +21,6 @@ from langsmith import run_trees as rt from langsmith import schemas as ls_schemas from langsmith import utils as ls_utils -from langsmith._internal import _orjson try: import pytest # type: ignore @@ -374,7 +374,7 @@ def _serde_example_values(values: VT) -> VT: if values is None: return values bts = ls_client._dumps_json(values) - return _orjson.loads(bts) + return orjson.loads(bts) class _LangSmithTestSuite: diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 8348b57d1..eb397b4c4 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -55,6 +55,7 @@ ) from urllib import parse as urllib_parse +import orjson import requests from requests import adapters as requests_adapters from requests_toolbelt import ( # type: ignore[import-untyped] @@ -68,7 +69,6 @@ from langsmith import env as ls_env from langsmith import schemas as ls_schemas from langsmith import utils as ls_utils -from langsmith._internal import _orjson from langsmith._internal._background_thread import ( TracingQueueItem, ) @@ -368,7 +368,6 @@ class Client: "_info", "_write_api_urls", "_settings", - "_manual_cleanup", ] def __init__( @@ -517,8 +516,6 @@ def __init__( self._settings: Union[ls_schemas.LangSmithSettings, None] = None - self._manual_cleanup = False - def _repr_html_(self) -> str: """Return an HTML representation of the instance with a link to the URL. @@ -1255,7 +1252,7 @@ def _hide_run_inputs(self, inputs: dict): if self._hide_inputs is True: return {} if self._anonymizer: - json_inputs = _orjson.loads(_dumps_json(inputs)) + json_inputs = orjson.loads(_dumps_json(inputs)) return self._anonymizer(json_inputs) if self._hide_inputs is False: return inputs @@ -1265,7 +1262,7 @@ def _hide_run_outputs(self, outputs: dict): if self._hide_outputs is True: return {} if self._anonymizer: - json_outputs = _orjson.loads(_dumps_json(outputs)) + json_outputs = orjson.loads(_dumps_json(outputs)) return self._anonymizer(json_outputs) if self._hide_outputs is False: return outputs @@ -1285,20 +1282,20 @@ def _batch_ingest_run_ops( # form the partial body and ids for op in ops: if isinstance(op, SerializedRunOperation): - curr_dict = _orjson.loads(op._none) + curr_dict = orjson.loads(op._none) if op.inputs: - curr_dict["inputs"] = _orjson.Fragment(op.inputs) + curr_dict["inputs"] = orjson.Fragment(op.inputs) if op.outputs: - curr_dict["outputs"] = _orjson.Fragment(op.outputs) + curr_dict["outputs"] = orjson.Fragment(op.outputs) if op.events: - curr_dict["events"] = _orjson.Fragment(op.events) + curr_dict["events"] = orjson.Fragment(op.events) if op.attachments: logger.warning( "Attachments are not supported when use_multipart_endpoint " "is False" ) ids_and_partial_body[op.operation].append( - (f"trace={op.trace_id},id={op.id}", _orjson.dumps(curr_dict)) + (f"trace={op.trace_id},id={op.id}", orjson.dumps(curr_dict)) ) elif isinstance(op, SerializedFeedbackOperation): logger.warning( @@ -1324,7 +1321,7 @@ def _batch_ingest_run_ops( and body_size + len(body_deque[0][1]) > size_limit_bytes ): self._post_batch_ingest_runs( - _orjson.dumps(body_chunks), + orjson.dumps(body_chunks), _context=f"\n{key}: {'; '.join(context_ids[key])}", ) body_size = 0 @@ -1332,12 +1329,12 @@ def _batch_ingest_run_ops( context_ids.clear() curr_id, curr_body = body_deque.popleft() body_size += len(curr_body) - body_chunks[key].append(_orjson.Fragment(curr_body)) + body_chunks[key].append(orjson.Fragment(curr_body)) context_ids[key].append(curr_id) if body_size: context = "; ".join(f"{k}: {'; '.join(v)}" for k, v in context_ids.items()) self._post_batch_ingest_runs( - _orjson.dumps(body_chunks), _context="\n" + context + orjson.dumps(body_chunks), _context="\n" + context ) def batch_ingest_runs( @@ -2762,7 +2759,7 @@ def create_dataset( "POST", "/datasets", headers={**self._headers, "Content-Type": "application/json"}, - data=_orjson.dumps(dataset), + data=orjson.dumps(dataset), ) ls_utils.raise_for_status_with_text(response) @@ -5678,10 +5675,6 @@ def push_prompt( ) return url - def cleanup(self) -> None: - """Manually trigger cleanup of the background thread.""" - self._manual_cleanup = True - def convert_prompt_to_openai_format( messages: Any, diff --git a/python/poetry.lock b/python/poetry.lock index 2b362f986..a2e1c3667 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -2070,4 +2070,4 @@ vcr = [] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "a5a6c61cba1b5ce9cf739700a780c2df63ff7aaa482c29de9910418263318586" +content-hash = "ca8fa5c9a82d58bea646d5e7e1089175111ddec2c24cd0b19920d1afd4dd93da" diff --git a/python/pyproject.toml b/python/pyproject.toml index 191d61b22..fc1d71da3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -31,7 +31,7 @@ pydantic = [ { version = "^2.7.4", python = ">=3.12.4" }, ] requests = "^2" -orjson = { version = "^3.9.14", markers = "platform_python_implementation != 'PyPy'" } +orjson = "^3.9.14" httpx = ">=0.23.0,<1" requests-toolbelt = "^1.0.0" diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index feec2c2f6..5dc1bbe1e 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -22,6 +22,7 @@ from unittest.mock import MagicMock, patch import dataclasses_json +import orjson import pytest import requests from multipart import MultipartParser, MultipartPart, parse_options_header @@ -32,7 +33,6 @@ import langsmith.utils as ls_utils from langsmith import AsyncClient, EvaluationResult, run_trees from langsmith import schemas as ls_schemas -from langsmith._internal import _orjson from langsmith._internal._serde import _serialize_json from langsmith.client import ( Client, @@ -848,7 +848,7 @@ class MyNamedTuple(NamedTuple): "set_with_class": set([MyClass(1)]), "my_mock": MagicMock(text="Hello, world"), } - res = _orjson.loads(_dumps_json(to_serialize)) + res = orjson.loads(_dumps_json(to_serialize)) assert ( "model_dump" not in caplog.text ), f"Unexpected error logs were emitted: {caplog.text}" @@ -898,7 +898,7 @@ def __repr__(self) -> str: my_cyclic = CyclicClass(other=CyclicClass(other=None)) my_cyclic.other.other = my_cyclic # type: ignore - res = _orjson.loads(_dumps_json({"cyclic": my_cyclic})) + res = orjson.loads(_dumps_json({"cyclic": my_cyclic})) assert res == {"cyclic": "my_cycles..."} expected = {"foo": "foo", "bar": 1} @@ -1142,7 +1142,7 @@ def test_batch_ingest_run_splits_large_batches( op for call in mock_session.request.call_args_list for reqs in ( - _orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else [] + orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else [] ) for op in reqs ] diff --git a/python/tests/unit_tests/test_operations.py b/python/tests/unit_tests/test_operations.py index 43d06ebc5..a6b5cdeb3 100644 --- a/python/tests/unit_tests/test_operations.py +++ b/python/tests/unit_tests/test_operations.py @@ -1,4 +1,5 @@ -from langsmith._internal import _orjson +import orjson + from langsmith._internal._operations import ( SerializedFeedbackOperation, SerializedRunOperation, @@ -13,7 +14,7 @@ def test_combine_serialized_queue_operations(): operation="post", id="id1", trace_id="trace_id1", - _none=_orjson.dumps({"a": 1}), + _none=orjson.dumps({"a": 1}), inputs="inputs1", outputs="outputs1", events="events1", @@ -23,7 +24,7 @@ def test_combine_serialized_queue_operations(): operation="patch", id="id1", trace_id="trace_id1", - _none=_orjson.dumps({"b": "2"}), + _none=orjson.dumps({"b": "2"}), inputs="inputs1-patched", outputs="outputs1-patched", events="events1", @@ -86,7 +87,7 @@ def test_combine_serialized_queue_operations(): operation="post", id="id1", trace_id="trace_id1", - _none=_orjson.dumps({"a": 1, "b": "2"}), + _none=orjson.dumps({"a": 1, "b": "2"}), inputs="inputs1-patched", outputs="outputs1-patched", events="events1",