From 129ba54528d950c445f57cda8ef2692ae6d70d90 Mon Sep 17 00:00:00 2001
From: Kane Norman <51185594+kanenorman@users.noreply.github.com>
Date: Mon, 2 Sep 2024 08:53:52 -0500
Subject: [PATCH] fix: Missing Nvidia embedding truncate mode (#1043)

* fix: Add NONE option to EmbeddingTruncateMode

* refactor: Validate input with _missing_ method

* test: Add EmbeddingTruncateMode test

* refactor: Revert "refactor: Validate input with _missing_ method"

This reverts commit 8334a50998db234076dc6ae3a7fe60525d5ef437.

* Update integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py

---------

Co-authored-by: Madeesh Kannan
---
 .../components/embedders/nvidia/truncate.py   |  2 +
 .../tests/test_embedding_truncate_mode.py     | 40 +++++++++++++++++++
 2 files changed, 42 insertions(+)
 create mode 100644 integrations/nvidia/tests/test_embedding_truncate_mode.py

diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py
index 2c32eabb1..3a8eb9d07 100644
--- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py
+++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py
@@ -6,10 +6,12 @@ class EmbeddingTruncateMode(Enum):
     Specifies how inputs to the NVIDIA embedding components are truncated.
     If START, the input will be truncated from the start.
     If END, the input will be truncated from the end.
+    If NONE, an error will be returned (if the input is too long).
     """
 
     START = "START"
     END = "END"
+    NONE = "NONE"
 
     def __str__(self):
         return self.value
diff --git a/integrations/nvidia/tests/test_embedding_truncate_mode.py b/integrations/nvidia/tests/test_embedding_truncate_mode.py
new file mode 100644
index 000000000..e74d0308c
--- /dev/null
+++ b/integrations/nvidia/tests/test_embedding_truncate_mode.py
@@ -0,0 +1,40 @@
+import pytest
+
+from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode
+
+
+class TestEmbeddingTruncateMode:
+    @pytest.mark.parametrize(
+        "mode, expected",
+        [
+            ("START", EmbeddingTruncateMode.START),
+            ("END", EmbeddingTruncateMode.END),
+            ("NONE", EmbeddingTruncateMode.NONE),
+            (EmbeddingTruncateMode.START, EmbeddingTruncateMode.START),
+            (EmbeddingTruncateMode.END, EmbeddingTruncateMode.END),
+            (EmbeddingTruncateMode.NONE, EmbeddingTruncateMode.NONE),
+        ],
+    )
+    def test_init_with_valid_mode(self, mode, expected):
+        assert EmbeddingTruncateMode(mode) == expected
+
+    def test_init_with_invalid_mode_raises_value_error(self):
+        with pytest.raises(ValueError):
+            invalid_mode = "INVALID"
+            EmbeddingTruncateMode(invalid_mode)
+
+    @pytest.mark.parametrize(
+        "mode, expected",
+        [
+            ("START", EmbeddingTruncateMode.START),
+            ("END", EmbeddingTruncateMode.END),
+            ("NONE", EmbeddingTruncateMode.NONE),
+        ],
+    )
+    def test_from_str_with_valid_mode(self, mode, expected):
+        assert EmbeddingTruncateMode.from_str(mode) == expected
+
+    def test_from_str_with_invalid_mode_raises_value_error(self):
+        with pytest.raises(ValueError):
+            invalid_mode = "INVALID"
+            EmbeddingTruncateMode.from_str(invalid_mode)