Skip to content

Commit

Permalink
fix: Missing Nvidia embedding truncate mode (#1043)
Browse files Browse the repository at this point in the history
* fix: Add NONE option to EmbeddingTruncateMode

* refactor: Validate input with _missing_ method

* test: Add EmbeddingTruncateMode test

* refactor: Revert "refactor: Validate input with _missing_ method"

This reverts commit 8334a50.

* Update integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py

---------

Co-authored-by: Madeesh Kannan <[email protected]>
  • Loading branch information
kanenorman and shadeMe authored Sep 2, 2024
1 parent 18bb61d commit 129ba54
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ class EmbeddingTruncateMode(Enum):
Specifies how inputs to the NVIDIA embedding components are truncated.
If START, the input will be truncated from the start.
If END, the input will be truncated from the end.
If NONE, an error will be returned (if the input is too long).
"""

START = "START"
END = "END"
NONE = "NONE"

def __str__(self):
return self.value
Expand Down
40 changes: 40 additions & 0 deletions integrations/nvidia/tests/test_embedding_truncate_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest

from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode


class TestEmbeddingTruncateMode:
@pytest.mark.parametrize(
"mode, expected",
[
("START", EmbeddingTruncateMode.START),
("END", EmbeddingTruncateMode.END),
("NONE", EmbeddingTruncateMode.NONE),
(EmbeddingTruncateMode.START, EmbeddingTruncateMode.START),
(EmbeddingTruncateMode.END, EmbeddingTruncateMode.END),
(EmbeddingTruncateMode.NONE, EmbeddingTruncateMode.NONE),
],
)
def test_init_with_valid_mode(self, mode, expected):
assert EmbeddingTruncateMode(mode) == expected

def test_init_with_invalid_mode_raises_value_error(self):
with pytest.raises(ValueError):
invalid_mode = "INVALID"
EmbeddingTruncateMode(invalid_mode)

@pytest.mark.parametrize(
"mode, expected",
[
("START", EmbeddingTruncateMode.START),
("END", EmbeddingTruncateMode.END),
("NONE", EmbeddingTruncateMode.NONE),
],
)
def test_from_str_with_valid_mode(self, mode, expected):
assert EmbeddingTruncateMode.from_str(mode) == expected

def test_from_str_with_invalid_mode_raises_value_error(self):
with pytest.raises(ValueError):
invalid_mode = "INVALID"
EmbeddingTruncateMode.from_str(invalid_mode)

0 comments on commit 129ba54

Please sign in to comment.