Skip to content

Commit

Permalink
Some craziness
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Sep 27, 2024
1 parent 86aab86 commit eeec25d
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 88 deletions.
36 changes: 16 additions & 20 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
)
from urllib import parse as urllib_parse

import orjson
import msgspec
import requests
from requests import adapters as requests_adapters
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -182,13 +182,13 @@ def _serialize_json(obj: Any, depth: int = 0, serialize_py: bool = True) -> Any:
try:
if depth >= _MAX_DEPTH:
try:
return orjson.loads(_dumps_json_single(obj))
return msgspec.json.decode(_dumps_json_single(obj))
except BaseException:
return repr(obj)
if isinstance(obj, bytes):
return obj.decode("utf-8")
if isinstance(obj, (set, tuple)):
return orjson.loads(_dumps_json_single(list(obj)))
return msgspec.json.decode(_dumps_json_single(list(obj)))

serialization_methods = [
("model_dump_json", True), # Pydantic V2
Expand All @@ -206,7 +206,7 @@ def _serialize_json(obj: Any, depth: int = 0, serialize_py: bool = True) -> Any:
)
if isinstance(json_str, str):
return json.loads(json_str)
return orjson.loads(
return msgspec.json.decode(
_dumps_json(
json_str, depth=depth + 1, serialize_py=serialize_py
)
Expand All @@ -226,7 +226,7 @@ def _serialize_json(obj: Any, depth: int = 0, serialize_py: bool = True) -> Any:
filtered = {
k: v if v is not obj else repr(v) for k, v in all_attrs.items()
}
return orjson.loads(
return msgspec.json.decode(
_dumps_json(filtered, depth=depth + 1, serialize_py=serialize_py)
)
return repr(obj)
Expand All @@ -245,13 +245,9 @@ def _dumps_json_single(
obj: Any, default: Optional[Callable[[Any], Any]] = None
) -> bytes:
try:
return orjson.dumps(
return msgspec.json.encode(
obj,
default=default,
option=orjson.OPT_SERIALIZE_NUMPY
| orjson.OPT_SERIALIZE_DATACLASS
| orjson.OPT_SERIALIZE_UUID
| orjson.OPT_NON_STR_KEYS,
enc_hook=default,
)
except TypeError as e:
# Usually caused by UTF surrogate characters
Expand All @@ -262,10 +258,10 @@ def _dumps_json_single(
ensure_ascii=True,
).encode("utf-8")
try:
result = orjson.dumps(
orjson.loads(result.decode("utf-8", errors="surrogateescape"))
result = msgspec.json.encode(
msgspec.json.decode(result.decode("utf-8", errors="surrogateescape"))
)
except orjson.JSONDecodeError:
except msgspec.DecodeError:
result = _elide_surrogates(result)
return result

Expand Down Expand Up @@ -1307,7 +1303,7 @@ def _hide_run_inputs(self, inputs: dict):
if self._hide_inputs is True:
return {}
if self._anonymizer:
json_inputs = orjson.loads(_dumps_json(inputs))
json_inputs = msgspec.json.decode(_dumps_json(inputs))
return self._anonymizer(json_inputs)
if self._hide_inputs is False:
return inputs
Expand All @@ -1317,7 +1313,7 @@ def _hide_run_outputs(self, outputs: dict):
if self._hide_outputs is True:
return {}
if self._anonymizer:
json_outputs = orjson.loads(_dumps_json(outputs))
json_outputs = msgspec.json.decode(_dumps_json(outputs))
return self._anonymizer(json_outputs)
if self._hide_outputs is False:
return outputs
Expand Down Expand Up @@ -1430,19 +1426,19 @@ def batch_ingest_runs(
while body:
if body_size > 0 and body_size + len(body[0]) > size_limit_bytes:
self._post_batch_ingest_runs(
orjson.dumps(body_chunks),
msgspec.json.encode(body_chunks),
_context=f"\n{key}: {'; '.join(context_ids[key])}",
)
body_size = 0
body_chunks.clear()
context_ids.clear()
body_size += len(body[0])
body_chunks[key].append(orjson.Fragment(body.popleft()))
body_chunks[key].append(msgspec.Raw(body.popleft()))
context_ids[key].append(ids_.popleft())
if body_size:
context = "; ".join(f"{k}: {'; '.join(v)}" for k, v in context_ids.items())
self._post_batch_ingest_runs(
orjson.dumps(body_chunks), _context="\n" + context
msgspec.json.encode(body_chunks), _context="\n" + context
)

def _post_batch_ingest_runs(self, body: bytes, *, _context: str):
Expand Down Expand Up @@ -2616,7 +2612,7 @@ def create_dataset(
"POST",
"/datasets",
headers={**self._headers, "Content-Type": "application/json"},
data=orjson.dumps(dataset),
data=msgspec.json.encode(dataset),
)
ls_utils.raise_for_status_with_text(response)

Expand Down
Loading

0 comments on commit eeec25d

Please sign in to comment.