From d812d3746a0a585fe8c9ae02b2d5ba29848f94c8 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sat, 3 Feb 2024 14:02:04 -0800 Subject: [PATCH] Update lexicon phone to pronunciations (#13) --- CMakeLists.txt | 4 ++-- docs/source/changelog.rst | 5 +++++ kalpy/fstext/lexicon.py | 14 ++++++-------- tests/test_align.py | 6 ++---- tests/test_decode.py | 6 ++---- 5 files changed, 17 insertions(+), 18 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9da52e7..22bf605 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,7 +39,7 @@ if (MSVC) set(CMAKE_INSTALL_RPATH "$ORIGIN;$ORIGIN/../lib;$ORIGIN/../../tools/openfst/lib") endif () -find_package(CUDA) +find_package(CUDAToolkit) find_package(pybind11 REQUIRED) include_directories(extensions) @@ -86,7 +86,7 @@ target_link_libraries(_kalpy PUBLIC kaldi-base kaldi-chain fstscript ) -if(CUDA_FOUND) +if(CUDAToolkit_FOUND) target_link_libraries(_kalpy PUBLIC kaldi-cudadecoder kaldi-cudafeat ) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 9821e96..effd78f 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -4,6 +4,11 @@ Changelog ========= +0.6.0 +----- + +- Fixed a bug in feature archives where fMLLR transforms were being ignored + 0.5.1 ----- diff --git a/kalpy/fstext/lexicon.py b/kalpy/fstext/lexicon.py index bb806cf..62a7e46 100644 --- a/kalpy/fstext/lexicon.py +++ b/kalpy/fstext/lexicon.py @@ -916,10 +916,10 @@ def _create_pronunciation_string( def phones_to_pronunciations( self, - text: str, word_symbols: typing.List[int], intervals: typing.List[CtmInterval], transcription: bool = False, + text: str = None, ) -> HierarchicalCtm: phones = [x.symbol for x in intervals] @@ -928,12 +928,10 @@ def phones_to_pronunciations( phones, transcription=transcription, ) - if transcription: - actual_words = [self.word_table.find(x) for x in word_symbols] - if not text: - text = " ".join(actual_words) - else: - actual_words = [x for x in text.split() if x != self.silence_word] + + actual_words = [self.word_table.find(x) for x in word_symbols] + if not text: + text = " ".join(actual_words) word_intervals = [] current_phone_index = 0 current_word_index = 0 @@ -1002,10 +1000,10 @@ def __init__( def phones_to_pronunciations( self, - text: str, word_symbols: typing.List[int], intervals: typing.List[CtmInterval], transcription: bool = False, + text: str = None, ) -> HierarchicalCtm: phone_symbols = [x.symbol for x in intervals] word_symbols = [self.word_table.find(x) for x in text.split()] diff --git a/tests/test_align.py b/tests/test_align.py index 2aa6254..089283d 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -76,8 +76,7 @@ def test_align_sat_first_pass( alignment = alignment_archive["1-1"] assert len(alignment.alignment) == 2672 intervals = alignment.generate_ctm(aligner.transition_model, lc.phone_table) - text = " ".join(lc.word_table.find(x) for x in alignment.words) - ctm = lc.phones_to_pronunciations(text, alignment.words, intervals) + ctm = lc.phones_to_pronunciations(alignment.words, intervals) ctm.export_textgrid(textgrid_name, file_duration=26.72) reference_alignment_archive = AlignmentArchive(reference_first_pass_ali_path) reference_alignment = reference_alignment_archive["1-1"] @@ -129,8 +128,7 @@ def test_align_sat_second_pass( alignment = alignment_archive["1-1"] assert len(alignment.alignment) == 2672 intervals = alignment.generate_ctm(aligner.transition_model, lc.phone_table) - text = " ".join(lc.word_table.find(x) for x in alignment.words) - ctm = lc.phones_to_pronunciations(text, alignment.words, intervals) + ctm = lc.phones_to_pronunciations(alignment.words, intervals) ctm.export_textgrid(textgrid_name, file_duration=26.72) reference_alignment_archive = AlignmentArchive(reference_second_pass_ali_path) diff --git a/tests/test_decode.py b/tests/test_decode.py index 0c11313..3995ba3 100644 --- a/tests/test_decode.py +++ b/tests/test_decode.py @@ -77,8 +77,7 @@ def test_decode_sat_first_pass( alignment = alignment_archive["1-1"] assert len(alignment.alignment) == 2672 intervals = alignment.generate_ctm(decoder.transition_model, lc.phone_table) - text = " ".join(lc.word_table.find(x) for x in alignment.words) - ctm = lc.phones_to_pronunciations(text, alignment.words, intervals) + ctm = lc.phones_to_pronunciations(alignment.words, intervals) ctm.export_textgrid(textgrid_name, file_duration=26.72) @@ -128,8 +127,7 @@ def test_decode_sat_second_pass( alignment = alignment_archive["1-1"] assert len(alignment.alignment) == 2672 intervals = alignment.generate_ctm(decoder.transition_model, lc.phone_table) - text = " ".join(lc.word_table.find(x) for x in alignment.words) - ctm = lc.phones_to_pronunciations(text, alignment.words, intervals) + ctm = lc.phones_to_pronunciations(alignment.words, intervals) ctm.export_textgrid(textgrid_name, file_duration=26.72)