Skip to content

Commit

Permalink
sdk patch: raise errors on update/create examples when passing mismat…
Browse files Browse the repository at this point in the history
…ching length sequences (#1238)
  • Loading branch information
isahers1 authored Nov 20, 2024
1 parent a07d3d6 commit b95f6d2
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
31 changes: 31 additions & 0 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3413,6 +3413,22 @@ def create_examples(

if dataset_id is None:
dataset_id = self.read_dataset(dataset_name=dataset_name).id

sequence_args = {
"outputs": outputs,
"metadata": metadata,
"splits": splits,
"ids": ids,
"source_run_ids": source_run_ids,
}
# Since inputs are required, we will check against them
input_len = len(inputs)
for arg_name, arg_value in sequence_args.items():
if arg_value is not None and len(arg_value) != input_len:
raise ValueError(
f"Length of {arg_name} ({len(arg_value)}) does not match"
f" length of inputs ({input_len})"
)
examples = [
{
"inputs": in_,
Expand Down Expand Up @@ -3816,6 +3832,21 @@ def update_examples(
Dict[str, Any]
The response from the server (specifies the number of examples updated).
"""
sequence_args = {
"inputs": inputs,
"outputs": outputs,
"metadata": metadata,
"splits": splits,
"dataset_ids": dataset_ids,
}
# Since inputs are required, we will check against them
examples_len = len(example_ids)
for arg_name, arg_value in sequence_args.items():
if arg_value is not None and len(arg_value) != examples_len:
raise ValueError(
f"Length of {arg_name} ({len(arg_value)}) does not match"
f" length of examples ({examples_len})"
)
examples = [
{
"id": id_,
Expand Down
41 changes: 41 additions & 0 deletions python/tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,3 +1018,44 @@ def create_encoder(*args, **kwargs):
myobj["key_1"]

assert not caplog.records


def test_examples_length_validation(langchain_client: Client) -> None:
"""Test that mismatched lengths raise ValueError for create and update examples."""
dataset_name = "__test_examples_length_validation" + uuid4().hex[:4]
dataset = langchain_client.create_dataset(dataset_name=dataset_name)

# Test create_examples validation
inputs = [{"text": "hello"}, {"text": "world"}]
outputs = [{"response": "hi"}] # One less than inputs
with pytest.raises(ValueError) as exc_info:
langchain_client.create_examples(
inputs=inputs, outputs=outputs, dataset_id=dataset.id
)
assert "Length of outputs (1) does not match length of inputs (2)" in str(
exc_info.value
)

# Create some valid examples for testing update
langchain_client.create_examples(
inputs=[{"text": "hello"}, {"text": "world"}],
outputs=[{"response": "hi"}, {"response": "earth"}],
dataset_id=dataset.id,
)
example_ids = [
example.id for example in langchain_client.list_examples(dataset_id=dataset.id)
]

# Test update_examples validation
with pytest.raises(ValueError) as exc_info:
langchain_client.update_examples(
example_ids=example_ids,
inputs=[{"text": "new hello"}], # One less than example_ids
outputs=[{"response": "new hi"}, {"response": "new earth"}],
)
assert "Length of inputs (1) does not match length of examples (2)" in str(
exc_info.value
)

# Clean up
langchain_client.delete_dataset(dataset_id=dataset.id)

0 comments on commit b95f6d2

Please sign in to comment.