Skip to content

Commit

Permalink
Update lexicon phone to pronunciations (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcauliffe authored Feb 3, 2024
1 parent 33307be commit d812d37
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ if (MSVC)
set(CMAKE_INSTALL_RPATH "$ORIGIN;$ORIGIN/../lib;$ORIGIN/../../tools/openfst/lib")
endif ()

find_package(CUDA)
find_package(CUDAToolkit)

find_package(pybind11 REQUIRED)
include_directories(extensions)
Expand Down Expand Up @@ -86,7 +86,7 @@ target_link_libraries(_kalpy PUBLIC kaldi-base kaldi-chain
fstscript
)

if(CUDA_FOUND)
if(CUDAToolkit_FOUND)

target_link_libraries(_kalpy PUBLIC kaldi-cudadecoder kaldi-cudafeat
)
Expand Down
5 changes: 5 additions & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
Changelog
=========

0.6.0
-----

- Fixed a bug in feature archives where fMLLR transforms were being ignored

0.5.1
-----

Expand Down
14 changes: 6 additions & 8 deletions kalpy/fstext/lexicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,10 +916,10 @@ def _create_pronunciation_string(

def phones_to_pronunciations(
self,
text: str,
word_symbols: typing.List[int],
intervals: typing.List[CtmInterval],
transcription: bool = False,
text: str = None,
) -> HierarchicalCtm:

phones = [x.symbol for x in intervals]
Expand All @@ -928,12 +928,10 @@ def phones_to_pronunciations(
phones,
transcription=transcription,
)
if transcription:
actual_words = [self.word_table.find(x) for x in word_symbols]
if not text:
text = " ".join(actual_words)
else:
actual_words = [x for x in text.split() if x != self.silence_word]

actual_words = [self.word_table.find(x) for x in word_symbols]
if not text:
text = " ".join(actual_words)
word_intervals = []
current_phone_index = 0
current_word_index = 0
Expand Down Expand Up @@ -1002,10 +1000,10 @@ def __init__(

def phones_to_pronunciations(
self,
text: str,
word_symbols: typing.List[int],
intervals: typing.List[CtmInterval],
transcription: bool = False,
text: str = None,
) -> HierarchicalCtm:
phone_symbols = [x.symbol for x in intervals]
word_symbols = [self.word_table.find(x) for x in text.split()]
Expand Down
6 changes: 2 additions & 4 deletions tests/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def test_align_sat_first_pass(
alignment = alignment_archive["1-1"]
assert len(alignment.alignment) == 2672
intervals = alignment.generate_ctm(aligner.transition_model, lc.phone_table)
text = " ".join(lc.word_table.find(x) for x in alignment.words)
ctm = lc.phones_to_pronunciations(text, alignment.words, intervals)
ctm = lc.phones_to_pronunciations(alignment.words, intervals)
ctm.export_textgrid(textgrid_name, file_duration=26.72)
reference_alignment_archive = AlignmentArchive(reference_first_pass_ali_path)
reference_alignment = reference_alignment_archive["1-1"]
Expand Down Expand Up @@ -129,8 +128,7 @@ def test_align_sat_second_pass(
alignment = alignment_archive["1-1"]
assert len(alignment.alignment) == 2672
intervals = alignment.generate_ctm(aligner.transition_model, lc.phone_table)
text = " ".join(lc.word_table.find(x) for x in alignment.words)
ctm = lc.phones_to_pronunciations(text, alignment.words, intervals)
ctm = lc.phones_to_pronunciations(alignment.words, intervals)
ctm.export_textgrid(textgrid_name, file_duration=26.72)

reference_alignment_archive = AlignmentArchive(reference_second_pass_ali_path)
Expand Down
6 changes: 2 additions & 4 deletions tests/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def test_decode_sat_first_pass(
alignment = alignment_archive["1-1"]
assert len(alignment.alignment) == 2672
intervals = alignment.generate_ctm(decoder.transition_model, lc.phone_table)
text = " ".join(lc.word_table.find(x) for x in alignment.words)
ctm = lc.phones_to_pronunciations(text, alignment.words, intervals)
ctm = lc.phones_to_pronunciations(alignment.words, intervals)
ctm.export_textgrid(textgrid_name, file_duration=26.72)


Expand Down Expand Up @@ -128,8 +127,7 @@ def test_decode_sat_second_pass(
alignment = alignment_archive["1-1"]
assert len(alignment.alignment) == 2672
intervals = alignment.generate_ctm(decoder.transition_model, lc.phone_table)
text = " ".join(lc.word_table.find(x) for x in alignment.words)
ctm = lc.phones_to_pronunciations(text, alignment.words, intervals)
ctm = lc.phones_to_pronunciations(alignment.words, intervals)
ctm.export_textgrid(textgrid_name, file_duration=26.72)


Expand Down

0 comments on commit d812d37

Please sign in to comment.