From 96194ea0fcaec25bd35099cdedc7fc25f4d4289a Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Mon, 15 Jul 2024 08:39:57 -0700 Subject: [PATCH 1/2] 0.6.5 --- .readthedocs.yaml | 15 +- docs/source/_static/css/mfa.css | 267 ------------------------------- docs/source/changelog.rst | 6 + docs/source/conf.py | 2 +- extensions/decoder/decoder.cpp | 14 ++ kalpy/decoder/training_graphs.py | 215 +++++++++++++++++++------ kalpy/fstext/lexicon.py | 162 ++++++++++--------- rtd_environment.yml | 2 +- 8 files changed, 289 insertions(+), 394 deletions(-) delete mode 100644 docs/source/_static/css/mfa.css diff --git a/.readthedocs.yaml b/.readthedocs.yaml index be6b7b2..bb5b082 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,19 +1,28 @@ version: 2 build: - os: "ubuntu-20.04" + os: "ubuntu-22.04" tools: python: "mambaforge-4.10" + jobs: + post_checkout: + - git fetch --unshallow || true + pre_install: + - git update-index --assume-unchanged .rtd-environment.yml docs/conf.py sphinx: - configuration: docs/source/conf.py + builder: html + configuration: docs/source/conf.py + fail_on_warning: false + +formats: + - htmlzip conda: environment: rtd_environment.yml # This part is necessary otherwise the project is not built python: - version: 3.9 install: - method: pip path: . diff --git a/docs/source/_static/css/mfa.css b/docs/source/_static/css/mfa.css deleted file mode 100644 index a925241..0000000 --- a/docs/source/_static/css/mfa.css +++ /dev/null @@ -1,267 +0,0 @@ - -/* use Gentium Plus - Regular in .woff format */ -@font-face { - font-family: GentiumPlusW; - src: url(../fonts/GentiumPlus-Regular.woff2); -} -/* use Gentium Plus - Italic in .woff2 format */ -@font-face { - font-family: GentiumPlusW; - font-style: italic; - src: url(../fonts/GentiumPlus-Italic.woff2); -} -/* use Gentium Plus - Bold in .woff2 format */ -@font-face { - font-family: GentiumPlusW; - font-weight: bold; - src: url(../fonts/GentiumPlus-Bold.woff2); -} -/* use Gentium Plus - Bold Italic in .woff2 format */ -@font-face { - font-family: GentiumPlusW; - font-weight: bold; - font-style: italic; - src: url(../fonts/GentiumPlus-BoldItalic.woff2); -} - - -:root { - --base-blue: #003566; - --dark-blue: #001D3D; - --very-dark-blue: #000814; - --light-blue: #0E63B3; - --very-light-blue: #7AB5E6; - --base-yellow: #FFC300; - --dark-yellow: #E3930D; - --light-yellow: #FFD60A; -} - -html[data-theme="light"] { - --sd-color-primary: var(--base-blue); - --mfa-admonition-text-color: #cecece; - --sd-color-dark: var(--base-blue); - --sd-color-primary-text: #FFC300; - --sd-color-primary-highlight: #FFC300; - --pst-color-primary: var(--base-blue); - --pst-color-warning: var(--light-yellow); - --pst-color-info: var(--light-blue); - --pst-color-admonition-default: var(--light-blue); - - --pst-color-link: var(--light-blue); - --pst-color-link-hover: var(--dark-blue); - - --pst-color-active-navigation: var(--dark-blue); - --pst-color-hover-navigation: var(--base-yellow); - - --pst-color-navbar-link: var(--base-blue); - --pst-color-navbar-link-hover: var(--pst-color-hover-navigation); - --pst-color-navbar-link-active: var(--pst-color-active-navigation); - - --pst-color-sidebar-link: var(--base-blue); - --pst-color-sidebar-caption: var(--base-blue); - --pst-color-sidebar-link-hover: var(--pst-color-hover-navigation); - --pst-color-sidebar-link-active: var(--pst-color-active-navigation); - - --pst-color-toc-link: var(--base-blue); - --pst-color-toc-link-hover: var(--pst-color-hover-navigation); - --pst-color-toc-link-active: var(--pst-color-active-navigation); -} -/******************************************************************************* -* dark theme -* -* all the variables used for dark theme coloring -*/ -html[data-theme="dark"] { - --sd-color-primary: var(--base-blue); - --sd-color-card-text: var(--base-yellow); - --sd-color-dark: var(--base-blue); - --sd-color-primary-text: var(--base-yellow); - --sd-color-primary-highlight: var(--base-yellow); - --pst-color-primary: var(--base-yellow); - --pst-color-warning: var(--light-yellow); - --pst-color-info: var(--light-blue); ---mfa-admonition-text-color: var(--pst-color-text-base); - --pst-color-link: var(--very-light-blue); - --pst-color-link-hover: var(--light-yellow); - - --pst-color-active-navigation: var(--base-yellow); - --pst-color-hover-navigation: var(--very-light-blue); - - --pst-color-navbar-link: var(--light-blue); - --pst-color-navbar-link-hover: var(--pst-color-hover-navigation); - --pst-color-navbar-link-active: var(--pst-color-active-navigation); - - --pst-color-sidebar-link: var(--base-yellow); - --pst-color-sidebar-caption: var(--base-yellow); - --pst-color-sidebar-link-hover: var(--pst-color-hover-navigation); - --pst-color-sidebar-link-active: var(--pst-color-active-navigation); - - --pst-color-toc-link: var(--base-yellow); - --pst-color-toc-link-hover: var(--pst-color-hover-navigation); - --pst-color-toc-link-active: var(--pst-color-active-navigation); -} - -.container, .container-xl, .container-lg { - max-width: 2400px !important; -} - -.wy-nav-content { - max-width: 1200px !important; -} -.wy-table-responsive table td { - white-space: normal !important; -} -.wy-table-responsive { - overflow: visible !important; -} -.wy-table-responsive table td, -.wy-table-responsive table th { - white-space: normal; -} - - -a.external::after{ -content: "\f35d"; -font-size: 0.75em; -text-align: center; -vertical-align: middle; -padding-bottom: 0.45em; -font-family: "Font Awesome 5 Free"; -font-weight: 900; -} - -/******************************************************************************* -* light theme -* -* all the variables used for light theme coloring -*/ -html[data-theme="light"] { - --base-blue: #003566; - --dark-blue: #001D3D; - --light-blue: #0E63B3; - --base-yellow: #FFC300; - --light-yellow: #FFD60A; - --sd-color-primary: #003566; - --sd-color-dark: #003566; - --sd-color-primary-text: #FFC300; - --sd-color-primary-highlight: #FFC300; - --pst-color-primary: var(--base-blue); - --pst-color-warning: var(--light-yellow); - --pst-color-info: var(--light-blue); - - --pst-color-link: var(--light-blue); - --pst-color-link-hover: var(--dark-blue); - - --pst-color-active-navigation: var(--dark-blue); - --pst-color-hover-navigation: var(--base-yellow); - - --pst-color-navbar-link: var(--base-blue); - --pst-color-navbar-link-hover: var(--pst-color-hover-navigation); - --pst-color-navbar-link-active: var(--pst-color-active-navigation); - - --pst-color-sidebar-link: var(--base-blue); - --pst-color-sidebar-caption: var(--base-blue); - --pst-color-sidebar-link-hover: var(--pst-color-hover-navigation); - --pst-color-sidebar-link-active: var(--pst-color-active-navigation); - - --pst-color-toc-link: var(--base-blue); - --pst-color-toc-link-hover: var(--pst-color-hover-navigation); - --pst-color-toc-link-active: var(--pst-color-active-navigation); -} -/******************************************************************************* -* light theme -* -* all the variables used for light theme coloring -*/ -html[data-theme="dark"] { - --base-blue: #003566; - --dark-blue: #001D3D; - --light-blue: #0E63B3; - --very-light-blue: #7AB5E6; - --base-yellow: #FFC300; - --light-yellow: #FFD60A; - --sd-color-primary: var(--base-blue); - --sd-color-dark: var(--base-blue); - --sd-color-primary-text: var(--base-yellow); - --sd-color-primary-highlight: var(--base-blue); - --pst-color-primary: var(--base-yellow); - --pst-color-warning: var(--light-yellow); - --pst-color-info: var(--light-blue); - - --pst-color-link: var(--very-light-blue); - --pst-color-link-hover: var(--light-yellow); - - --pst-color-active-navigation: var(--base-yellow); - --pst-color-hover-navigation: var(--very-light-blue); - - --pst-color-navbar-link: var(--light-blue); - --pst-color-navbar-link-hover: var(--pst-color-hover-navigation); - --pst-color-navbar-link-active: var(--pst-color-active-navigation); - - --pst-color-sidebar-link: var(--base-yellow); - --pst-color-sidebar-caption: var(--base-yellow); - --pst-color-sidebar-link-hover: var(--pst-color-hover-navigation); - --pst-color-sidebar-link-active: var(--pst-color-active-navigation); - - --pst-color-toc-link: var(--base-yellow); - --pst-color-toc-link-hover: var(--pst-color-hover-navigation); - --pst-color-toc-link-active: var(--pst-color-active-navigation); -} - -.sd-btn-primary{ -font-weight: bold; -} - -.sd-btn-primary:hover{ -background-color: var(--light-blue) !important; -} -.i-navigation{ - color: var(--sd-color-primary); - padding: 20px; -} -html[data-theme="dark"] .i-navigation{ - color: var(--sd-color-primary-text); -} - -.navbar-light .navbar-nav li a.nav-link:{ -font-size: 1.15em; -} - -.rst-table-cell{ -width: 100%; -height: 100%; -display: inline-block; -text-align: center; - -} -div[class*="highlight-"] { - text-align: left; -} - -.ipa-inline { - font-family: "GentiumPlusW"; - font-size: 1.1em; - font-weight: 500; - } - -.ipa-highlight, .ipa-inline { -color: var(--pst-color-inline-code); -} - -.supported { -background-color: #E9F6EC; -} - -.not-supported { -background-color: #FBEAEC; -} -#navbar-icon-links i.fa-github-square::before, i.fa-github-square::before { - color: inherit; -} - -html[data-theme="light"] dt:target { -background-color: var(--base-yellow); -} -html[data-theme="dark"] dt:target{ - background-color: var(--dark-blue); -} diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index effd78f..88079a9 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -4,6 +4,12 @@ Changelog ========= +0.6.5 +----- + +- Changed how the :code:`silence_probability` parameter of LexiconCompiler works with pronunciations that have silence probabilities, so that setting it to 0.0 will ensure that no optional silences are included +- Added the functionality for adding interjection words in between each word in an alignment + 0.6.0 ----- diff --git a/docs/source/conf.py b/docs/source/conf.py index 5e7bcee..742f686 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -362,7 +362,7 @@ html_static_path = ["_static"] html_css_files = [ "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/fontawesome.min.css", - "css/mfa.css", + "https://montreal-forced-aligner.readthedocs.io/en/latest/_static/css/mfa.css", ] # Add any extra paths that contain custom files (such as robots.txt or diff --git a/extensions/decoder/decoder.cpp b/extensions/decoder/decoder.cpp index 537232e..2e5c917 100644 --- a/extensions/decoder/decoder.cpp +++ b/extensions/decoder/decoder.cpp @@ -1770,6 +1770,20 @@ void pybind_training_graph_compiler(py::module &m) { "CompileGraphs allows you to compile a number of graphs at the same " "time. This consumes more memory but is faster.", py::arg("word_fsts"), py::arg("out_fst")) + .def("CompileGraphs", + + [](PyClass& gc, const std::vector *> &word_fsts){ + + py::gil_scoped_release gil_release; + std::vector* > fsts; + + bool ans = gc.CompileGraphs(word_fsts, &fsts); + return fsts; + }, + "CompileGraphs allows you to compile a number of graphs at the same " + "time. This consumes more memory but is faster.", + py::arg("word_fsts"), + py::return_value_policy::take_ownership) .def("CompileGraphFromText", &PyClass::CompileGraphFromText, "This version creates an FST from the text and calls CompileGraph.", diff --git a/kalpy/decoder/training_graphs.py b/kalpy/decoder/training_graphs.py index 50f1ee1..3b7defc 100644 --- a/kalpy/decoder/training_graphs.py +++ b/kalpy/decoder/training_graphs.py @@ -102,9 +102,13 @@ def __init__( if disambiguation_symbols is None: disambiguation_symbols = [] self.disambiguation_symbols = disambiguation_symbols + self._kaldi_fst = self._fst + if not isinstance(self._kaldi_fst, VectorFst): + self._kaldi_fst = VectorFst.from_pynini(self._fst) def __del__(self): del self._compiler + del self._kaldi_fst del self._fst def to_int(self, word: str) -> int: @@ -156,6 +160,8 @@ def export_graphs( transcripts: typing.Iterable[typing.Tuple[str, str]], write_scp: bool = False, callback: typing.Callable = None, + interjection_words: typing.List[str] = None, + cutoff_pattern: str = None, ): """ Export training graphs to a kaldi archive file (i.e., fsts.ark) @@ -168,6 +174,12 @@ def export_graphs( Dictionary of utterance IDs to transcripts write_scp: bool Flag for whether an SCP file should be generated as well + callback: callable, optional + Optional callback function for progress updates + interjection_words: list[str], optional + List of words to add as interjections to the transcripts + cutoff_pattern: str, optional + Cutoff symbol to use for inserting cutoffs before words """ write_specifier = generate_write_specifier(file_name, write_scp) writer = VectorFstWriter(write_specifier) @@ -175,10 +187,15 @@ def export_graphs( transcript_batch = [] num_done = 0 num_error = 0 + logger.debug(f"DISAMBIGUATION: {self.lexicon_compiler.disambiguation}") for key, transcript in transcripts: keys.append(key) if self.use_g2p: transcript_batch.append(transcript) + elif interjection_words: + transcript_batch.append( + self.generate_utterance_graph(transcript, interjection_words, cutoff_pattern) + ) else: transcript_batch.append([self.to_int(x) for x in transcript.split()]) if len(keys) >= self.batch_size: @@ -186,8 +203,14 @@ def export_graphs( fsts = [] for t in transcript_batch: fsts.append(self.compile_fst(t)) + elif interjection_words: + # fsts = [] + # for t in transcript_batch: + # fsts.append(self.compile_fst(t, interjection_words, cutoff_pattern)) + fsts = self.compiler.CompileGraphs(transcript_batch) else: fsts = self.compiler.CompileGraphsFromText(transcript_batch) + del transcript_batch assert len(fsts) == len(keys) batch_done = 0 batch_error = 0 @@ -198,20 +221,29 @@ def export_graphs( batch_error += 1 continue writer.Write(str(key), fst) + del fst batch_done += 1 num_done += batch_done num_error += batch_error + logger.debug(f"Done {num_done} utterances, errors on {num_error}.") if callback: callback(batch_done) keys = [] transcript_batch = [] + del fsts if keys: if self.use_g2p: fsts = [] for t in transcript_batch: fsts.append(self.compile_fst(t)) + elif interjection_words: + # fsts = [] + # for t in transcript_batch: + # fsts.append(self.compile_fst(t, interjection_words, cutoff_pattern)) + fsts = self.compiler.CompileGraphs(transcript_batch) else: fsts = self.compiler.CompileGraphsFromText(transcript_batch) + del transcript_batch assert len(fsts) == len(keys) batch_done = 0 batch_error = 0 @@ -223,15 +255,133 @@ def export_graphs( continue writer.Write(str(key), fst) batch_done += 1 + del fst num_done += batch_done num_error += batch_error + del fsts if callback: callback(batch_done) writer.Close() logger.info(f"Done {num_done} utterances, errors on {num_error}.") + def generate_utterance_graph( + self, + transcript: str, + interjection_words: typing.List[str] = None, + cutoff_pattern: str = None, + ) -> typing.Optional[VectorFst]: + if interjection_words is None: + interjection_words = [] + default_interjection_cost = 3.0 + cutoff_interjection_cost = default_interjection_cost + cutoff_symbol = -1 + if cutoff_pattern is not None and self.word_table.member(cutoff_pattern): + cutoff_symbol = self.to_int(cutoff_pattern) + interjection_costs = {} + if interjection_words: + for iw in interjection_words: + if not self.word_table.member(iw): + continue + if isinstance(interjection_words, dict): + interjection_cost = interjection_words[iw] * default_interjection_cost + else: + interjection_cost = default_interjection_cost + if isinstance(iw, str): + iw = self.to_int(iw) + if iw == cutoff_symbol: + cutoff_interjection_cost = interjection_cost + continue + interjection_costs[iw] = interjection_cost + g = pynini.Fst() + start_state = g.add_state() + g.set_start(start_state) + if isinstance(transcript, str): + transcript = transcript.split() + for word_symbol in transcript: + if not isinstance(word_symbol, int): + word_symbol = self.to_int(word_symbol) + interjection_state = g.add_state() + for iw_symbol, interjection_cost in interjection_costs.items(): + g.add_arc( + start_state, + pywrapfst.Arc( + iw_symbol, + iw_symbol, + pywrapfst.Weight(g.weight_type(), interjection_cost), + interjection_state, + ), + ) + if cutoff_pattern is not None: + cutoff_word = f"{cutoff_pattern[:-1]}-{self.word_table.find(word_symbol)}{cutoff_pattern[-1]}" + if self.word_table.member(cutoff_word): + iw_symbol = self.to_int(cutoff_word) + g.add_arc( + start_state, + pywrapfst.Arc( + iw_symbol, + iw_symbol, + pywrapfst.Weight(g.weight_type(), cutoff_interjection_cost), + interjection_state, + ), + ) + g.add_arc( + start_state, + pywrapfst.Arc( + word_symbol, + word_symbol, + pywrapfst.Weight(g.weight_type(), default_interjection_cost), + interjection_state, + ), + ) + g.add_arc( + start_state, + pywrapfst.Arc( + self.word_table.find(""), + self.word_table.find(""), + pywrapfst.Weight(g.weight_type(), 1.0), + interjection_state, + ), + ) + final_state = g.add_state() + g.add_arc( + interjection_state, + pywrapfst.Arc( + word_symbol, + word_symbol, + pywrapfst.Weight.one(g.weight_type()), + final_state, + ), + ) + start_state = final_state + final_state = g.add_state() + for iw_symbol, interjection_cost in interjection_costs.items(): + g.add_arc( + start_state, + pywrapfst.Arc( + iw_symbol, + iw_symbol, + pywrapfst.Weight(g.weight_type(), interjection_cost), + final_state, + ), + ) + g.add_arc( + start_state, + pywrapfst.Arc( + self.word_table.find(""), + self.word_table.find(""), + pywrapfst.Weight.one(g.weight_type()), + final_state, + ), + ) + g.set_final(final_state, pywrapfst.Weight.one(g.weight_type())) + g = VectorFst.from_pynini(g) + return g + def compile_fst( - self, transcript: str, interjection_words: typing.List[str] = None + self, + transcript: str, + interjection_words: typing.List[str] = None, + cutoff_pattern: str = None, ) -> typing.Optional[VectorFst]: """ Compile a transcript to a training graph @@ -240,6 +390,10 @@ def compile_fst( ---------- transcript: str Orthographic transcript to compile + interjection_words: list[str], optional + List of words to add as interjections to the transcript + cutoff_pattern: str, optional + Cutoff symbol to use for inserting cutoffs before words Returns ------- @@ -247,7 +401,6 @@ def compile_fst( Training graph of transcript """ if self.use_g2p: - g_fst = pynini.accep(transcript, token_type=self.word_table) lg_fst = pynini.compose(g_fst, self._fst, compose_filter="alt_sequence") lg_fst = lg_fst.project("output").rmepsilon() @@ -286,54 +439,18 @@ def compile_fst( fst, self.transition_model, disambig_syms_in, self.options.self_loop_scale ) elif interjection_words: - g = pynini.Fst() - start_state = g.add_state() - g.set_start(start_state) - for w in transcript.split(): - word_symbol = self.to_int(w) - word_initial_state = g.add_state() - for iw in interjection_words: - if not self.lexicon_compiler.word_table.member(iw): - continue - iw_symbol = self.to_int(iw) - g.add_arc( - word_initial_state - 1, - pywrapfst.Arc( - iw_symbol, - iw_symbol, - pywrapfst.Weight(g.weight_type(), 4.0), - word_initial_state, - ), - ) - word_final_state = g.add_state() - g.add_arc( - word_initial_state, - pywrapfst.Arc( - word_symbol, - word_symbol, - pywrapfst.Weight.one(g.weight_type()), - word_final_state, - ), - ) - g.add_arc( - word_initial_state - 1, - pywrapfst.Arc( - word_symbol, - word_symbol, - pywrapfst.Weight.one(g.weight_type()), - word_final_state, - ), - ) - g.set_final(word_final_state, pywrapfst.Weight.one(g.weight_type())) - - lg = pynini.compose(self.lexicon_compiler.fst, g) - lg.optimize() - lg.arcsort("olabel") - lg_fst = VectorFst.from_pynini(lg) + g = self.generate_utterance_graph(transcript, interjection_words, cutoff_pattern) + # fst = VectorFst() + # self.compiler.CompileGraph(g, fst) + # lg_fst = pynini.compose(self._fst, g, compose_filter="alt_sequence") + # lg_fst = VectorFst.from_pynini(lg_fst) + lg_fst = fst_table_compose(self._kaldi_fst, g) - disambig_syms_in = [] - if self.lexicon_compiler is not None and self.lexicon_compiler.disambiguation: - disambig_syms_in = self.lexicon_compiler.disambiguation_symbols + disambig_syms_in = ( + [] + if not self.lexicon_compiler.disambiguation + else self.lexicon_compiler.disambiguation_symbols + ) lg_fst = fst_determinize_star(lg_fst, use_log=True) fst_minimize_encoded(lg_fst) fst_push_special(lg_fst) diff --git a/kalpy/fstext/lexicon.py b/kalpy/fstext/lexicon.py index 22fc8c6..c9e6ce7 100644 --- a/kalpy/fstext/lexicon.py +++ b/kalpy/fstext/lexicon.py @@ -194,7 +194,11 @@ def __init__( self.word_end_label = word_end_label self.start_state = 0 self.non_silence_state = 1 - self.silence_state = 2 + + if self.silence_probability: + self.silence_state = 2 + else: + self.silence_state = None def clear(self): self.pronunciations = [] @@ -373,8 +377,9 @@ def create_fsts(self, phonological_rule_fst: pynini.Fst = None): self._align_fst.add_state() # Silence state = 2 - self._fst.add_state() - self._align_fst.add_state() + if self.silence_probability: + self._fst.add_state() + self._align_fst.add_state() self._align_fst.set_start(self.start_state) # initial no silence @@ -397,40 +402,46 @@ def create_fsts(self, phonological_rule_fst: pynini.Fst = None): ), ) # initial silence - self._fst.add_arc( - self.start_state, - pywrapfst.Arc( - self.phone_table.find(self.silence_phone), - self.word_table.find(self.silence_word), - pywrapfst.Weight(self._fst.weight_type(), initial_silence_cost), - self.silence_state, - ), - ) - self._align_fst.add_arc( - self.start_state, - pywrapfst.Arc( - self.phone_table.find(self.silence_phone), - self.word_table.find(self.silence_word), - pywrapfst.Weight(self._align_fst.weight_type(), initial_silence_cost), - self.silence_state, - ), - ) + if self.silence_probability: + self._fst.add_arc( + self.start_state, + pywrapfst.Arc( + self.phone_table.find(self.silence_phone), + self.word_table.find(self.silence_word), + pywrapfst.Weight(self._fst.weight_type(), initial_silence_cost), + self.silence_state, + ), + ) + self._align_fst.add_arc( + self.start_state, + pywrapfst.Arc( + self.phone_table.find(self.silence_phone), + self.word_table.find(self.silence_word), + pywrapfst.Weight(self._align_fst.weight_type(), initial_silence_cost), + self.silence_state, + ), + ) for pron in self.pronunciations: self.add_pronunciation(pron, phonological_rule_fst) - if final_silence_cost > 0: - self._fst.set_final( - self.silence_state, pywrapfst.Weight(self._fst.weight_type(), final_silence_cost) - ) - self._align_fst.set_final( - self.silence_state, - pywrapfst.Weight(self._align_fst.weight_type(), final_silence_cost), - ) - else: - self._fst.set_final(self.silence_state, pywrapfst.Weight.one(self._fst.weight_type())) - self._align_fst.set_final( - self.silence_state, pywrapfst.Weight.one(self._align_fst.weight_type()) - ) + + if self.silence_probability: + if final_silence_cost > 0: + self._fst.set_final( + self.silence_state, + pywrapfst.Weight(self._fst.weight_type(), final_silence_cost), + ) + self._align_fst.set_final( + self.silence_state, + pywrapfst.Weight(self._align_fst.weight_type(), final_silence_cost), + ) + else: + self._fst.set_final( + self.silence_state, pywrapfst.Weight.one(self._fst.weight_type()) + ) + self._align_fst.set_final( + self.silence_state, pywrapfst.Weight.one(self._align_fst.weight_type()) + ) if final_non_silence_cost > 0: self._fst.set_final( self.non_silence_state, @@ -449,7 +460,8 @@ def create_fsts(self, phonological_rule_fst: pynini.Fst = None): ) if ( - self._fst.num_states() <= self.silence_state + 1 + self._fst.num_states() + <= (self.silence_state if self.silence_probability else self.non_silence_state) + 1 or self._fst.start() == pywrapfst.NO_STATE_ID ): num_words = self.word_table.num_symbols() @@ -653,17 +665,18 @@ def add_pronunciation( ), ) # Silence before the pronunciation - self._fst.add_arc( - self.silence_state, - pywrapfst.Arc( - arc.ilabel, - word_symbol, - pywrapfst.Weight( - self._fst.weight_type(), pron_cost + silence_before_cost + if self.silence_probability: + self._fst.add_arc( + self.silence_state, + pywrapfst.Arc( + arc.ilabel, + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + silence_before_cost + ), + arc.nextstate + start_index, ), - arc.nextstate + start_index, - ), - ) + ) # No silence before the pronunciation self._align_fst.add_arc( @@ -678,17 +691,18 @@ def add_pronunciation( ), ) # Silence before the pronunciation - self._align_fst.add_arc( - self.silence_state, - pywrapfst.Arc( - self.phone_table.find(self.word_begin_label), - word_symbol, - pywrapfst.Weight( - self._fst.weight_type(), pron_cost + silence_before_cost + if self.silence_probability: + self._align_fst.add_arc( + self.silence_state, + pywrapfst.Arc( + self.phone_table.find(self.word_begin_label), + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + silence_before_cost + ), + arc.nextstate + align_start_index - 1, ), - arc.nextstate + align_start_index - 1, - ), - ) + ) else: self._fst.add_arc( state + start_index, @@ -733,15 +747,16 @@ def add_pronunciation( ), ) # Silence following the pronunciation - self._fst.add_arc( - num_new_states + start_index, - pywrapfst.Arc( - self.phone_table.find(self.silence_phone), - word_eps_symbol, - pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), - self.silence_state, - ), - ) + if self.silence_probability: + self._fst.add_arc( + num_new_states + start_index, + pywrapfst.Arc( + self.phone_table.find(self.silence_phone), + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), + self.silence_state, + ), + ) self._align_fst.add_arc( num_new_states + align_start_index, pywrapfst.Arc( @@ -763,15 +778,16 @@ def add_pronunciation( ), ) # Silence following the pronunciation - self._align_fst.add_arc( - num_new_states + align_start_index + 1, - pywrapfst.Arc( - self.phone_table.find(self.silence_phone), - word_eps_symbol, - pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), - self.silence_state, - ), - ) + if self.silence_probability: + self._align_fst.add_arc( + num_new_states + align_start_index + 1, + pywrapfst.Arc( + self.phone_table.find(self.silence_phone), + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), + self.silence_state, + ), + ) @property def kaldi_fst(self) -> VectorFst: diff --git a/rtd_environment.yml b/rtd_environment.yml index 21c94d2..c42921c 100644 --- a/rtd_environment.yml +++ b/rtd_environment.yml @@ -17,4 +17,4 @@ dependencies: - mock - setuptools-scm - kaldi =*=cpu* - - kalpy + - kalpy =*=cpu* From 323bf49f0a9a7a422d0cc2e4b3655a469d2c7bf1 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Mon, 15 Jul 2024 09:16:48 -0700 Subject: [PATCH 2/2] Update TrainingGraphCompiler signature --- docs/source/changelog.rst | 3 +- kalpy/decoder/training_graphs.py | 50 +++++++------------------------- tests/test_decoder.py | 8 ++--- 3 files changed, 14 insertions(+), 47 deletions(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 88079a9..1644ae6 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -7,7 +7,8 @@ Changelog 0.6.5 ----- -- Changed how the :code:`silence_probability` parameter of LexiconCompiler works with pronunciations that have silence probabilities, so that setting it to 0.0 will ensure that no optional silences are included +- Changed how the :code:`silence_probability` parameter of :code:`LexiconCompiler` works with pronunciations that have silence probabilities, so that setting it to 0.0 will ensure that no optional silences are included +- Changed :code:`TrainingGraphCompiler` signature to require a :code:`LexiconCompiler` rather than an FST/path and a word table - Added the functionality for adding interjection words in between each word in an alignment 0.6.0 diff --git a/kalpy/decoder/training_graphs.py b/kalpy/decoder/training_graphs.py index 3b7defc..8e78436 100644 --- a/kalpy/decoder/training_graphs.py +++ b/kalpy/decoder/training_graphs.py @@ -42,7 +42,7 @@ class TrainingGraphCompiler: Path to model file tree_path: str Path to tree file - lexicon: typing.Union[pathlib.Path, str, :class:`~kalpy.fstext.lexicon.LexiconCompiler`, VectorFst] + lexicon_compiler: :class:`~kalpy.fstext.lexicon.LexiconCompiler` Lexicon compiler to use in generating training graphs transition_scale: float Scale on transitions, typically set to 0 as it will be defined during alignment @@ -63,8 +63,7 @@ def __init__( self, model_path: typing.Union[pathlib.Path, str], tree_path: typing.Union[pathlib.Path, str], - lexicon: typing.Union[pathlib.Path, str, LexiconCompiler, VectorFst], - words_symbols: typing.Union[pathlib.Path, str, pywrapfst.SymbolTable], + lexicon_compiler: LexiconCompiler, transition_scale: float = 0.0, self_loop_scale: float = 0.0, batch_size: int = 1000, @@ -81,35 +80,22 @@ def __init__( self._compiler = None self.use_g2p = use_g2p self.lexicon_path = None - self.lexicon_compiler = None - self._fst = None - if isinstance(lexicon, LexiconCompiler): - self.lexicon_compiler = lexicon - if self.use_g2p: - self._fst = self.lexicon_compiler.fst - else: - self._fst = self.lexicon_compiler.fst - disambiguation_symbols = self.lexicon_compiler.disambiguation_symbols - elif isinstance(lexicon, VectorFst): - self._fst = lexicon - else: - self.lexicon_path = str(lexicon) - if isinstance(words_symbols, pywrapfst.SymbolTable): - self.word_table = words_symbols - else: - self.word_table = pywrapfst.SymbolTable.read_text(words_symbols) + self.lexicon_compiler = lexicon_compiler self.oov_word = oov_word if disambiguation_symbols is None: disambiguation_symbols = [] self.disambiguation_symbols = disambiguation_symbols - self._kaldi_fst = self._fst + self._kaldi_fst = self.lexicon_compiler.fst if not isinstance(self._kaldi_fst, VectorFst): - self._kaldi_fst = VectorFst.from_pynini(self._fst) + self._kaldi_fst = VectorFst.from_pynini(self._kaldi_fst) def __del__(self): del self._compiler del self._kaldi_fst - del self._fst + + @property + def word_table(self): + return self.lexicon_compiler.word_table def to_int(self, word: str) -> int: """ @@ -129,26 +115,16 @@ def to_int(self, word: str) -> int: return self.word_table.find(word) return self.word_table.find(self.oov_word) - @property - def fst(self): - if self._fst is None: - return pynini.Fst.read(self.lexicon_path) - @property def compiler(self): if self._compiler is None: - if self._fst is None: - if self.lexicon_compiler is None: - self._fst = pynini.Fst.read(str(self.lexicon_path)) - else: - self._fst = self.lexicon_compiler.fst disambiguation_symbols = [] if self.lexicon_compiler is not None and self.lexicon_compiler.disambiguation: disambiguation_symbols = self.lexicon_compiler.disambiguation_symbols self._compiler = _TrainingGraphCompiler( self.transition_model, self.tree, - VectorFst.from_pynini(self._fst), + self._kaldi_fst, disambiguation_symbols, self.options, ) @@ -204,9 +180,6 @@ def export_graphs( for t in transcript_batch: fsts.append(self.compile_fst(t)) elif interjection_words: - # fsts = [] - # for t in transcript_batch: - # fsts.append(self.compile_fst(t, interjection_words, cutoff_pattern)) fsts = self.compiler.CompileGraphs(transcript_batch) else: fsts = self.compiler.CompileGraphsFromText(transcript_batch) @@ -237,9 +210,6 @@ def export_graphs( for t in transcript_batch: fsts.append(self.compile_fst(t)) elif interjection_words: - # fsts = [] - # for t in transcript_batch: - # fsts.append(self.compile_fst(t, interjection_words, cutoff_pattern)) fsts = self.compiler.CompileGraphs(transcript_batch) else: fsts = self.compiler.CompileGraphsFromText(transcript_batch) diff --git a/tests/test_decoder.py b/tests/test_decoder.py index f2bdd67..068acb1 100644 --- a/tests/test_decoder.py +++ b/tests/test_decoder.py @@ -22,9 +22,7 @@ def test_training_graphs( lc = LexiconCompiler(position_dependent_phones=False) lc.load_pronunciations(dictionary_path) lc.fst.write(str(mono_temp_dir.joinpath("lexicon.fst"))) - gc = TrainingGraphCompiler( - mono_model_path, mono_tree_path, str(mono_temp_dir.joinpath("lexicon.fst")), lc.word_table - ) + gc = TrainingGraphCompiler(mono_model_path, mono_tree_path, lc) graph = kaldi_to_pynini(gc.compile_fst(acoustic_corpus_text)) assert graph.num_states() > 0 assert graph.start() != pywrapfst.NO_STATE_ID @@ -51,9 +49,7 @@ def test_training_graphs_sat( lc.fst.write(str(sat_temp_dir.joinpath("L_debug.fst"))) lc.word_table.write_text(str(sat_temp_dir.joinpath("words.txt"))) lc.phone_table.write_text(str(sat_temp_dir.joinpath("phones.txt"))) - gc = TrainingGraphCompiler( - sat_model_path, sat_tree_path, str(sat_temp_dir.joinpath("L_debug.fst")), lc.word_table - ) + gc = TrainingGraphCompiler(sat_model_path, sat_tree_path, lc) graph = kaldi_to_pynini(gc.compile_fst(acoustic_corpus_text)) assert graph.num_states() > 0 assert graph.start() != pywrapfst.NO_STATE_ID