Skip to content

Commit

Permalink
Fix timestamps tests (NVIDIA#11053)
Browse files Browse the repository at this point in the history
* change timestamps tests

Signed-off-by: Monica Sekoyan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: monica-sekoyan <[email protected]>

---------

Signed-off-by: Monica Sekoyan <[email protected]>
Signed-off-by: monica-sekoyan <[email protected]>
Co-authored-by: monica-sekoyan <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
  • Loading branch information
2 people authored and Hainan Xu committed Nov 5, 2024
1 parent e09c9f7 commit 1403a6b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
20 changes: 14 additions & 6 deletions tests/collections/asr/decoding/test_ctc_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


def char_vocabulary():
return [' ', 'a', 'b', 'c', 'd', 'e', 'f']
return [' ', 'a', 'b', 'c', 'd', 'e', 'f', '.']


@pytest.fixture()
Expand Down Expand Up @@ -60,11 +60,19 @@ def check_char_timestamps(hyp: Hypothesis, decoding: CTCDecoding):
words = list(filter(lambda x: x != '', words))
assert len(hyp.timestep['word']) == len(words)

segments_count = sum([hyp.text.count(seperator) for seperator in decoding.segment_seperators])
if hyp.text[-1] not in decoding.segment_seperators:
segments_count += 1
segments = []
segment = []

assert len(hyp.timestep['segment']) == segments_count
for word in words:
segment.append(word)
if word[-1] in decoding.segment_seperators:
segments.append(' '.join(segment))
segment = []

if segment:
segments.append(' '.join(segment))

assert len(hyp.timestep['segment']) == len(segments)


def check_subword_timestamps(hyp: Hypothesis, decoding: CTCBPEDecoding):
Expand All @@ -83,7 +91,7 @@ def check_subword_timestamps(hyp: Hypothesis, decoding: CTCBPEDecoding):
assert len(chars) == len(all_chars)

segments_count = sum([hyp.text.count(seperator) for seperator in decoding.segment_seperators])
if hyp.text[-1] not in decoding.segment_seperators:
if not hyp.text or hyp.text[-1] not in decoding.segment_seperators:
segments_count += 1

assert len(hyp.timestep['segment']) == segments_count
Expand Down
20 changes: 14 additions & 6 deletions tests/collections/asr/decoding/test_rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


def char_vocabulary():
return [' ', 'a', 'b', 'c', 'd', 'e', 'f']
return [' ', 'a', 'b', 'c', 'd', 'e', 'f', '.']


@pytest.fixture()
Expand Down Expand Up @@ -129,11 +129,19 @@ def check_char_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTDecoding):
words = list(filter(lambda x: x != '', words))
assert len(hyp.timestep['word']) == len(words)

segments_count = sum([hyp.text.count(seperator) for seperator in decoding.segment_seperators])
if hyp.text[-1] not in decoding.segment_seperators:
segments_count += 1
segments = []
segment = []

assert len(hyp.timestep['segment']) == segments_count
for word in words:
segment.append(word)
if word[-1] in decoding.segment_seperators:
segments.append(' '.join(segment))
segment = []

if segment:
segments.append(' '.join(segment))

assert len(hyp.timestep['segment']) == len(segments)


def check_subword_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTBPEDecoding):
Expand All @@ -152,7 +160,7 @@ def check_subword_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTBPEDecodi
assert len(chars) == len(all_chars)

segments_count = sum([hyp.text.count(seperator) for seperator in decoding.segment_seperators])
if hyp.text[-1] not in decoding.segment_seperators:
if not hyp.text or hyp.text[-1] not in decoding.segment_seperators:
segments_count += 1

assert len(hyp.timestep['segment']) == segments_count
Expand Down

0 comments on commit 1403a6b

Please sign in to comment.