From 54ff955f874bf3922ab878f55054ee41a8b95d57 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 23 Feb 2024 15:58:38 -0800 Subject: [PATCH] Orjson non str key fix (#471) --- python/langsmith/client.py | 52 ++++++++----------- python/tests/integration_tests/test_client.py | 2 +- python/tests/unit_tests/test_client.py | 16 +++--- 3 files changed, 32 insertions(+), 38 deletions(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index a4506dcaf..0585a7397 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -3,7 +3,6 @@ from __future__ import annotations import collections -import dataclasses import datetime import functools import importlib @@ -156,7 +155,6 @@ def _default_retry_config() -> Retry: return ls_utils.LangSmithRetry(**retry_params) # type: ignore -_PRIMITIVE_TYPES = (str, int, float, bool) _MAX_DEPTH = 2 @@ -164,21 +162,13 @@ def _serialize_json(obj: Any, depth: int = 0) -> Any: try: if depth >= _MAX_DEPTH: try: - return orjson.loads(orjson.dumps(obj)) + return orjson.loads(_dumps_json_single(obj)) except BaseException: return repr(obj) - if isinstance(obj, datetime.datetime): - return obj.isoformat() - if isinstance(obj, uuid.UUID): - return str(obj) - if obj is None or isinstance(obj, _PRIMITIVE_TYPES): - return obj if isinstance(obj, bytes): return obj.decode("utf-8") - if isinstance(obj, (set, list, tuple)): - return [_serialize_json(x, depth + 1) for x in list(obj)] - if isinstance(obj, dict): - return {k: _serialize_json(v, depth + 1) for k, v in obj.items()} + if isinstance(obj, (set, tuple)): + return orjson.loads(_dumps_json_single(list(obj))) serialization_methods = [ ("model_dump_json", True), # Pydantic V2 @@ -197,26 +187,33 @@ def _serialize_json(obj: Any, depth: int = 0) -> Any: except Exception as e: logger.debug(f"Failed to serialize {type(obj)} to JSON: {e}") return repr(obj) - - if dataclasses.is_dataclass(obj): - # Regular dataclass - return dataclasses.asdict(obj) if hasattr(obj, "__slots__"): all_attrs = {slot: getattr(obj, slot, None) for slot in obj.__slots__} elif hasattr(obj, "__dict__"): all_attrs = vars(obj) else: return repr(obj) - return { - k: _serialize_json(v, depth=depth + 1) if v is not obj else repr(v) - for k, v in all_attrs.items() - } + filtered = {k: v if v is not obj else repr(v) for k, v in all_attrs.items()} + return orjson.loads(_dumps_json(filtered, depth=depth + 1)) except BaseException as e: logger.debug(f"Failed to serialize {type(obj)} to JSON: {e}") return repr(obj) -def _dumps_json(obj: Any) -> bytes: +def _dumps_json_single( + obj: Any, default: Optional[Callable[[Any], Any]] = None +) -> bytes: + return orjson.dumps( + obj, + default=default, + option=orjson.OPT_SERIALIZE_NUMPY + | orjson.OPT_SERIALIZE_DATACLASS + | orjson.OPT_SERIALIZE_UUID + | orjson.OPT_NON_STR_KEYS, + ) + + +def _dumps_json(obj: Any, depth: int = 0) -> bytes: """Serialize an object to a JSON formatted string. Parameters @@ -231,13 +228,7 @@ def _dumps_json(obj: Any) -> bytes: str The JSON formatted string. """ - return orjson.dumps( - obj, - default=_serialize_json, - option=orjson.OPT_SERIALIZE_NUMPY - | orjson.OPT_SERIALIZE_DATACLASS - | orjson.OPT_SERIALIZE_UUID, - ) + return _dumps_json_single(obj, functools.partial(_serialize_json, depth=depth)) def close_session(session: requests.Session) -> None: @@ -1009,7 +1000,6 @@ def create_run( } if not self._filter_for_sampling([run_create]): return - run_create = self._run_transform(run_create) self._insert_runtime_env([run_create]) @@ -1032,7 +1022,6 @@ def create_run( "Accept": "application/json", "Content-Type": "application/json", } - self.request_with_retries( "post", f"{self.api_url}/runs", @@ -3517,6 +3506,7 @@ def _tracing_thread_handle_batch( try: client.batch_ingest_runs(create=create, update=update, pre_sampled=True) except Exception: + logger.error("Error in tracing queue", exc_info=True) # exceptions are logged elsewhere, but we need to make sure the # background thread continues to run pass diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py index 52722ce05..94b03e728 100644 --- a/python/tests/integration_tests/test_client.py +++ b/python/tests/integration_tests/test_client.py @@ -359,7 +359,6 @@ def test_create_chat_example( langchain_client.delete_dataset(dataset_id=dataset.id) -@freeze_time("2023-01-01") def test_batch_ingest_runs(langchain_client: Client) -> None: _session = "__test_batch_ingest_runs" trace_id = uuid4() @@ -418,6 +417,7 @@ def test_batch_ingest_runs(langchain_client: Client) -> None: raise LangSmithError("Runs not created yet") except LangSmithError: time.sleep(wait) + wait += 1 else: raise ValueError("Runs not created in time") assert len(runs) == 2 diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index c25466517..1e62c1e88 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -20,6 +20,7 @@ import attr import dataclasses_json +import orjson import pytest import requests from pydantic import BaseModel @@ -29,6 +30,7 @@ import langsmith.utils as ls_utils from langsmith.client import ( Client, + _dumps_json, _get_api_key, _get_api_url, _is_langchain_hosted, @@ -577,8 +579,7 @@ class MyNamedTuple(NamedTuple): "cyclic": CyclicClass(), "cyclic2": cycle_2, } - - res = json.loads(json.dumps(to_serialize, default=_serialize_json)) + res = orjson.loads(_dumps_json(to_serialize)) expected = { "uid": str(uid), "time": current_time.isoformat(), @@ -616,10 +617,13 @@ class MyNamedTuple(NamedTuple): } assert set(expected) == set(res) for k, v in expected.items(): - if callable(v): - assert v(res[k]) - else: - assert res[k] == v + try: + if callable(v): + assert v(res[k]), f"Failed for {k}" + else: + assert res[k] == v, f"Failed for {k}" + except AssertionError: + raise @patch("langsmith.client.requests.Session", autospec=True)