From 7fccec775044a12879731572e45e07a5d18d33f9 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Mon, 18 Dec 2023 17:08:29 +0000 Subject: [PATCH] fix tests --- test.py | 10 +++++----- wtpsplit/extract.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test.py b/test.py index 21047399..64fc0c50 100644 --- a/test.py +++ b/test.py @@ -5,13 +5,13 @@ def test_split_ort(): wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) - splits = wtp.split("This is a test sentence This is another test sentence.") + splits = wtp.split("This is a test sentence This is another test sentence.", threshold=0.005) assert splits == ["This is a test sentence ", "This is another test sentence."] def test_split_torch(): wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None) - splits = wtp.split("This is a test sentence This is another test sentence.") + splits = wtp.split("This is a test sentence This is another test sentence.", threshold=0.005) assert splits == ["This is a test sentence ", "This is another test sentence."] def test_split_torch_canine(): @@ -27,7 +27,7 @@ def test_move_device(): def test_strip_whitespace(): wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None) - splits = wtp.split("This is a test sentence This is another test sentence. ", strip_whitespace=True) + splits = wtp.split("This is a test sentence This is another test sentence. ", strip_whitespace=True, threshold=0.005) assert splits == ["This is a test sentence", "This is another test sentence."] def test_split_long(): @@ -36,7 +36,7 @@ def test_split_long(): wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None) splits = wtp.split(prefix + " This is a test sentence. This is another test sentence.") - assert splits == [prefix + " ", "This is a test sentence. ", "This is another test sentence."] + assert splits == [prefix + " " + "This is a test sentence. ", "This is another test sentence."] def test_split_batched(): @@ -117,6 +117,6 @@ def test_split_threshold(): ) assert splits == ["This is a test sentence. This is another test sentence."] - splits = wtp.split("This is a test sentence. This is another test sentence.", threshold=0.0) + splits = wtp.split("This is a test sentence. This is another test sentence.", threshold=-1e-3) # space might still be included in a character split assert splits[:3] == list("Thi") diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index 1466d79a..98a544f5 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -20,7 +20,7 @@ def __call__(self, hashed_ids, attention_mask): logits = self.ort_session.run( ["logits"], { - "attention_mask": attention_mask, + "attention_mask": attention_mask.astype(np.float16), # ORT expects fp16 mask "hashed_ids": hashed_ids, }, )[0]