Skip to content

Commit

Permalink
Add additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
awinml committed Feb 21, 2024
1 parent de9bd77 commit 096b1d3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def from_str(cls, string: str) -> "PoolingMode":
"pooling_mode_max_tokens": PoolingMode.MAX,
"pooling_mode_mean_sqrt_len_tokens": PoolingMode.MEAN_SQRT_LEN,
"pooling_mode_weightedmean_tokens": PoolingMode.WEIGHTED_MEAN,
"pooling_mode_last_token": PoolingMode.LAST_TOKEN,
"pooling_mode_lasttoken": PoolingMode.LAST_TOKEN,
}

INVERSE_POOLING_MODES_MAP = {mode: name for name, mode in POOLING_MODES_MAP.items()}
Expand Down Expand Up @@ -120,6 +120,12 @@ def pool_embeddings(self) -> torch.tensor:
pooling_func_map = {
INVERSE_POOLING_MODES_MAP[self.pooling_mode]: True,
}
# By default, sentence-transformers uses mean pooling
# If multiple pooling methods are specified, the output dimension of the embeddings is scaled by the number of
# pooling methods selected
if self.pooling_mode != PoolingMode.MEAN:
pooling_func_map[INVERSE_POOLING_MODES_MAP[PoolingMode.MEAN]] = False

# First element of model_output contains all token embeddings
token_embeddings = self.model_output[0]
word_embedding_dimension = token_embeddings.size(dim=2)
Expand Down
26 changes: 17 additions & 9 deletions integrations/optimum/tests/test_optimum_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@ def backend():
return backend


class TestOptimumBackend:
    """Behavioral tests for the Optimum embedding backend's `embed` method.

    Relies on the module-level `backend` pytest fixture and the
    `PoolingMode` enum imported elsewhere in this test module.
    """

    def test_embed_output_order(self, backend):
        """Embeddings for a batch must come back in the same order as the input texts."""
        texts_to_embed = ["short text", "text that is longer than the other", "medium length text"]
        embeddings = backend.embed(texts_to_embed, normalize_embeddings=False, pooling_mode=PoolingMode.MEAN)

        # Compute individual embeddings in order
        expected_embeddings = []
        for text in texts_to_embed:
            expected_embeddings.append(backend.embed(text, normalize_embeddings=False, pooling_mode=PoolingMode.MEAN))

        # Assert that the embeddings are in the same order
        assert embeddings == expected_embeddings

    def test_run_pooling_modes(self, backend):
        """Every supported pooling mode should yield a 768-dimensional float vector.

        NOTE(review): 768 assumes the fixture model's hidden size — confirm
        against the model configured in the `backend` fixture.
        """
        for pooling_mode in PoolingMode:
            embedding = backend.embed("test text", normalize_embeddings=False, pooling_mode=pooling_mode)

            assert len(embedding) == 768
            assert all(isinstance(x, float) for x in embedding)

0 comments on commit 096b1d3

Please sign in to comment.