Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bminixhofer committed Dec 18, 2023
1 parent 51ff5db commit 7fccec7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
def test_split_ort():
wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"])

- splits = wtp.split("This is a test sentence This is another test sentence.")
+ splits = wtp.split("This is a test sentence This is another test sentence.", threshold=0.005)
assert splits == ["This is a test sentence ", "This is another test sentence."]

def test_split_torch():
wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None)

- splits = wtp.split("This is a test sentence This is another test sentence.")
+ splits = wtp.split("This is a test sentence This is another test sentence.", threshold=0.005)
assert splits == ["This is a test sentence ", "This is another test sentence."]

def test_split_torch_canine():
Expand All @@ -27,7 +27,7 @@ def test_move_device():
def test_strip_whitespace():
wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None)

- splits = wtp.split("This is a test sentence This is another test sentence. ", strip_whitespace=True)
+ splits = wtp.split("This is a test sentence This is another test sentence. ", strip_whitespace=True, threshold=0.005)
assert splits == ["This is a test sentence", "This is another test sentence."]

def test_split_long():
Expand All @@ -36,7 +36,7 @@ def test_split_long():
wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None)

splits = wtp.split(prefix + " This is a test sentence. This is another test sentence.")
- assert splits == [prefix + " ", "This is a test sentence. ", "This is another test sentence."]
+ assert splits == [prefix + " " + "This is a test sentence. ", "This is another test sentence."]


def test_split_batched():
Expand Down Expand Up @@ -117,6 +117,6 @@ def test_split_threshold():
)
assert splits == ["This is a test sentence. This is another test sentence."]

- splits = wtp.split("This is a test sentence. This is another test sentence.", threshold=0.0)
+ splits = wtp.split("This is a test sentence. This is another test sentence.", threshold=-1e-3)
# space might still be included in a character split
assert splits[:3] == list("Thi")
2 changes: 1 addition & 1 deletion wtpsplit/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __call__(self, hashed_ids, attention_mask):
logits = self.ort_session.run(
["logits"],
{
- "attention_mask": attention_mask,
+ "attention_mask": attention_mask.astype(np.float16), # ORT expects fp16 mask
"hashed_ids": hashed_ids,
},
)[0]
Expand Down

0 comments on commit 7fccec7

Please sign in to comment.