Skip to content

Commit

Permalink
Orjson non str key fix (#471)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Feb 23, 2024
1 parent ca944c6 commit 54ff955
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 38 deletions.
52 changes: 21 additions & 31 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import collections
import dataclasses
import datetime
import functools
import importlib
Expand Down Expand Up @@ -156,29 +155,20 @@ def _default_retry_config() -> Retry:
return ls_utils.LangSmithRetry(**retry_params) # type: ignore


_PRIMITIVE_TYPES = (str, int, float, bool)
_MAX_DEPTH = 2


def _serialize_json(obj: Any, depth: int = 0) -> Any:
try:
if depth >= _MAX_DEPTH:
try:
return orjson.loads(orjson.dumps(obj))
return orjson.loads(_dumps_json_single(obj))
except BaseException:
return repr(obj)
if isinstance(obj, datetime.datetime):
return obj.isoformat()
if isinstance(obj, uuid.UUID):
return str(obj)
if obj is None or isinstance(obj, _PRIMITIVE_TYPES):
return obj
if isinstance(obj, bytes):
return obj.decode("utf-8")
if isinstance(obj, (set, list, tuple)):
return [_serialize_json(x, depth + 1) for x in list(obj)]
if isinstance(obj, dict):
return {k: _serialize_json(v, depth + 1) for k, v in obj.items()}
if isinstance(obj, (set, tuple)):
return orjson.loads(_dumps_json_single(list(obj)))

serialization_methods = [
("model_dump_json", True), # Pydantic V2
Expand All @@ -197,26 +187,33 @@ def _serialize_json(obj: Any, depth: int = 0) -> Any:
except Exception as e:
logger.debug(f"Failed to serialize {type(obj)} to JSON: {e}")
return repr(obj)

if dataclasses.is_dataclass(obj):
# Regular dataclass
return dataclasses.asdict(obj)
if hasattr(obj, "__slots__"):
all_attrs = {slot: getattr(obj, slot, None) for slot in obj.__slots__}
elif hasattr(obj, "__dict__"):
all_attrs = vars(obj)
else:
return repr(obj)
return {
k: _serialize_json(v, depth=depth + 1) if v is not obj else repr(v)
for k, v in all_attrs.items()
}
filtered = {k: v if v is not obj else repr(v) for k, v in all_attrs.items()}
return orjson.loads(_dumps_json(filtered, depth=depth + 1))
except BaseException as e:
logger.debug(f"Failed to serialize {type(obj)} to JSON: {e}")
return repr(obj)


def _dumps_json(obj: Any) -> bytes:
def _dumps_json_single(
    obj: Any, default: Optional[Callable[[Any], Any]] = None
) -> bytes:
    """Serialize ``obj`` to JSON bytes with one direct ``orjson.dumps`` call.

    Parameters
    ----------
    obj : Any
        The object to serialize.
    default : Optional[Callable[[Any], Any]]
        Passed straight through as the ``default`` argument of
        ``orjson.dumps``; ``None`` when no callable is supplied.

    Returns
    -------
    bytes
        The value returned by ``orjson.dumps``.
    """
    # Same option set on every call; OPT_NON_STR_KEYS is the flag that lets
    # orjson accept dict keys that are not ``str``.
    dump_options = (
        orjson.OPT_SERIALIZE_NUMPY
        | orjson.OPT_SERIALIZE_DATACLASS
        | orjson.OPT_SERIALIZE_UUID
        | orjson.OPT_NON_STR_KEYS
    )
    return orjson.dumps(obj, default=default, option=dump_options)


def _dumps_json(obj: Any, depth: int = 0) -> bytes:
"""Serialize an object to a JSON formatted string.
Parameters
Expand All @@ -231,13 +228,7 @@ def _dumps_json(obj: Any) -> bytes:
str
The JSON formatted string.
"""
return orjson.dumps(
obj,
default=_serialize_json,
option=orjson.OPT_SERIALIZE_NUMPY
| orjson.OPT_SERIALIZE_DATACLASS
| orjson.OPT_SERIALIZE_UUID,
)
return _dumps_json_single(obj, functools.partial(_serialize_json, depth=depth))


def close_session(session: requests.Session) -> None:
Expand Down Expand Up @@ -1009,7 +1000,6 @@ def create_run(
}
if not self._filter_for_sampling([run_create]):
return

run_create = self._run_transform(run_create)
self._insert_runtime_env([run_create])

Expand All @@ -1032,7 +1022,6 @@ def create_run(
"Accept": "application/json",
"Content-Type": "application/json",
}

self.request_with_retries(
"post",
f"{self.api_url}/runs",
Expand Down Expand Up @@ -3517,6 +3506,7 @@ def _tracing_thread_handle_batch(
try:
client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
except Exception:
logger.error("Error in tracing queue", exc_info=True)
# exceptions are logged elsewhere, but we need to make sure the
# background thread continues to run
pass
Expand Down
2 changes: 1 addition & 1 deletion python/tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ def test_create_chat_example(
langchain_client.delete_dataset(dataset_id=dataset.id)


@freeze_time("2023-01-01")
def test_batch_ingest_runs(langchain_client: Client) -> None:
_session = "__test_batch_ingest_runs"
trace_id = uuid4()
Expand Down Expand Up @@ -418,6 +417,7 @@ def test_batch_ingest_runs(langchain_client: Client) -> None:
raise LangSmithError("Runs not created yet")
except LangSmithError:
time.sleep(wait)
wait += 1
else:
raise ValueError("Runs not created in time")
assert len(runs) == 2
Expand Down
16 changes: 10 additions & 6 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import attr
import dataclasses_json
import orjson
import pytest
import requests
from pydantic import BaseModel
Expand All @@ -29,6 +30,7 @@
import langsmith.utils as ls_utils
from langsmith.client import (
Client,
_dumps_json,
_get_api_key,
_get_api_url,
_is_langchain_hosted,
Expand Down Expand Up @@ -577,8 +579,7 @@ class MyNamedTuple(NamedTuple):
"cyclic": CyclicClass(),
"cyclic2": cycle_2,
}

res = json.loads(json.dumps(to_serialize, default=_serialize_json))
res = orjson.loads(_dumps_json(to_serialize))
expected = {
"uid": str(uid),
"time": current_time.isoformat(),
Expand Down Expand Up @@ -616,10 +617,13 @@ class MyNamedTuple(NamedTuple):
}
assert set(expected) == set(res)
for k, v in expected.items():
if callable(v):
assert v(res[k])
else:
assert res[k] == v
try:
if callable(v):
assert v(res[k]), f"Failed for {k}"
else:
assert res[k] == v, f"Failed for {k}"
except AssertionError:
raise


@patch("langsmith.client.requests.Session", autospec=True)
Expand Down

0 comments on commit 54ff955

Please sign in to comment.