From b95f6d2ca1bb364cd5de4c924fc47db945817d2b Mon Sep 17 00:00:00 2001 From: Isaac Francisco <78627776+isahers1@users.noreply.github.com> Date: Wed, 20 Nov 2024 12:04:05 -0800 Subject: [PATCH] sdk patch: raise errors on update/create examples when passing mismatching length sequences (#1238) --- python/langsmith/client.py | 31 ++++++++++++++ python/tests/integration_tests/test_client.py | 41 +++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 5e906507d..1647b790d 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -3413,6 +3413,22 @@ def create_examples( if dataset_id is None: dataset_id = self.read_dataset(dataset_name=dataset_name).id + + sequence_args = { + "outputs": outputs, + "metadata": metadata, + "splits": splits, + "ids": ids, + "source_run_ids": source_run_ids, + } + # Since inputs are required, we will check against them + input_len = len(inputs) + for arg_name, arg_value in sequence_args.items(): + if arg_value is not None and len(arg_value) != input_len: + raise ValueError( + f"Length of {arg_name} ({len(arg_value)}) does not match" + f" length of inputs ({input_len})" + ) examples = [ { "inputs": in_, @@ -3816,6 +3832,21 @@ def update_examples( Dict[str, Any] The response from the server (specifies the number of examples updated). """ + sequence_args = { + "inputs": inputs, + "outputs": outputs, + "metadata": metadata, + "splits": splits, + "dataset_ids": dataset_ids, + } + # Since inputs are required, we will check against them + examples_len = len(example_ids) + for arg_name, arg_value in sequence_args.items(): + if arg_value is not None and len(arg_value) != examples_len: + raise ValueError( + f"Length of {arg_name} ({len(arg_value)}) does not match" + f" length of examples ({examples_len})" + ) examples = [ { "id": id_, diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py index 57a6e2171..9bea700cd 100644 --- a/python/tests/integration_tests/test_client.py +++ b/python/tests/integration_tests/test_client.py @@ -1018,3 +1018,44 @@ def create_encoder(*args, **kwargs): myobj["key_1"] assert not caplog.records + + +def test_examples_length_validation(langchain_client: Client) -> None: + """Test that mismatched lengths raise ValueError for create and update examples.""" + dataset_name = "__test_examples_length_validation" + uuid4().hex[:4] + dataset = langchain_client.create_dataset(dataset_name=dataset_name) + + # Test create_examples validation + inputs = [{"text": "hello"}, {"text": "world"}] + outputs = [{"response": "hi"}] # One less than inputs + with pytest.raises(ValueError) as exc_info: + langchain_client.create_examples( + inputs=inputs, outputs=outputs, dataset_id=dataset.id + ) + assert "Length of outputs (1) does not match length of inputs (2)" in str( + exc_info.value + ) + + # Create some valid examples for testing update + langchain_client.create_examples( + inputs=[{"text": "hello"}, {"text": "world"}], + outputs=[{"response": "hi"}, {"response": "earth"}], + dataset_id=dataset.id, + ) + example_ids = [ + example.id for example in langchain_client.list_examples(dataset_id=dataset.id) + ] + + # Test update_examples validation + with pytest.raises(ValueError) as exc_info: + langchain_client.update_examples( + example_ids=example_ids, + inputs=[{"text": "new hello"}], # One less than example_ids + outputs=[{"response": "new hi"}, {"response": "new earth"}], + ) + assert "Length of inputs (1) does not match length of examples (2)" in str( + exc_info.value + ) + + # Clean up + langchain_client.delete_dataset(dataset_id=dataset.id)