diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py index fd16aca67713..7a16db4324bc 100644 --- a/tests/collections/asr/decoding/test_ctc_decoding.py +++ b/tests/collections/asr/decoding/test_ctc_decoding.py @@ -31,7 +31,7 @@ def char_vocabulary(): - return [' ', 'a', 'b', 'c', 'd', 'e', 'f'] + return [' ', 'a', 'b', 'c', 'd', 'e', 'f', '.'] @pytest.fixture() @@ -60,11 +60,19 @@ def check_char_timestamps(hyp: Hypothesis, decoding: CTCDecoding): words = list(filter(lambda x: x != '', words)) assert len(hyp.timestep['word']) == len(words) - segments_count = sum([hyp.text.count(seperator) for seperator in decoding.segment_seperators]) - if hyp.text[-1] not in decoding.segment_seperators: - segments_count += 1 + segments = [] + segment = [] - assert len(hyp.timestep['segment']) == segments_count + for word in words: + segment.append(word) + if word[-1] in decoding.segment_seperators: + segments.append(' '.join(segment)) + segment = [] + + if segment: + segments.append(' '.join(segment)) + + assert len(hyp.timestep['segment']) == len(segments) def check_subword_timestamps(hyp: Hypothesis, decoding: CTCBPEDecoding): @@ -83,7 +91,7 @@ def check_subword_timestamps(hyp: Hypothesis, decoding: CTCBPEDecoding): assert len(chars) == len(all_chars) segments_count = sum([hyp.text.count(seperator) for seperator in decoding.segment_seperators]) - if hyp.text[-1] not in decoding.segment_seperators: + if not hyp.text or hyp.text[-1] not in decoding.segment_seperators: segments_count += 1 assert len(hyp.timestep['segment']) == segments_count diff --git a/tests/collections/asr/decoding/test_rnnt_decoding.py b/tests/collections/asr/decoding/test_rnnt_decoding.py index 59da7b11d286..82b5d00bede6 100644 --- a/tests/collections/asr/decoding/test_rnnt_decoding.py +++ b/tests/collections/asr/decoding/test_rnnt_decoding.py @@ -35,7 +35,7 @@ def char_vocabulary(): - return [' ', 'a', 'b', 'c', 'd', 'e', 'f'] + return [' ', 'a', 'b', 'c', 'd', 'e', 'f', '.'] @pytest.fixture() @@ -129,11 +129,19 @@ def check_char_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTDecoding): words = list(filter(lambda x: x != '', words)) assert len(hyp.timestep['word']) == len(words) - segments_count = sum([hyp.text.count(seperator) for seperator in decoding.segment_seperators]) - if hyp.text[-1] not in decoding.segment_seperators: - segments_count += 1 + segments = [] + segment = [] - assert len(hyp.timestep['segment']) == segments_count + for word in words: + segment.append(word) + if word[-1] in decoding.segment_seperators: + segments.append(' '.join(segment)) + segment = [] + + if segment: + segments.append(' '.join(segment)) + + assert len(hyp.timestep['segment']) == len(segments) def check_subword_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTBPEDecoding): @@ -152,7 +160,7 @@ def check_subword_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTBPEDecodi assert len(chars) == len(all_chars) segments_count = sum([hyp.text.count(seperator) for seperator in decoding.segment_seperators]) - if hyp.text[-1] not in decoding.segment_seperators: + if not hyp.text or hyp.text[-1] not in decoding.segment_seperators: segments_count += 1 assert len(hyp.timestep['segment']) == segments_count