From f0264513423d2faeb00f02bd24e88e3707ab3d50 Mon Sep 17 00:00:00 2001
From: danemadsen
Date: Tue, 13 Aug 2024 17:21:13 +1000
Subject: [PATCH] work on punct tokens

---
 include/babylon.hpp |  4 ++--
 src/phonemizer.cpp  | 46 ++++++++++++++++++++++++++++++----------------
 2 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/include/babylon.hpp b/include/babylon.hpp
index 7f0d364..e9ad190 100644
--- a/include/babylon.hpp
+++ b/include/babylon.hpp
@@ -32,6 +32,7 @@ namespace DeepPhonemizer {
         ~Session();
 
         std::vector<std::string> g2p(const std::string& text);
+        std::vector<int64_t> g2p_tokens(const std::string& text);
 
     private:
         std::string lang;
@@ -39,9 +40,8 @@ namespace DeepPhonemizer {
         Ort::Session* session;
         SequenceTokenizer* text_tokenizer;
         SequenceTokenizer* phoneme_tokenizer;
-        std::unordered_map<std::string, std::vector<std::string>> dictionary;
 
-        std::vector<std::string> g2p_internal(const std::string& text);
+        std::vector<int64_t> g2p_tokens_internal(const std::string& text);
     };
 
     std::vector<std::string> clean_text(const std::string& text);
diff --git a/src/phonemizer.cpp b/src/phonemizer.cpp
index a3e81d1..d84db8a 100644
--- a/src/phonemizer.cpp
+++ b/src/phonemizer.cpp
@@ -8,6 +8,15 @@
 const std::array<const char*, 1> input_names = {"text"};
 const std::array<const char*, 1> output_names = {"output"};
 
+const std::vector<std::string> punctuation_symbols = {
+    ".",
+    ",",
+    ":",
+    ";",
+    "!",
+    "?"
+};
+
 std::vector<float> softmax(const std::vector<float>& logits) {
     float max_logit = *std::max_element(logits.begin(), logits.end());
     std::vector<float> probabilities(logits.size());
@@ -51,6 +60,10 @@ namespace DeepPhonemizer {
         for (const auto& symbol : symbols) {
             tokens.push_back(symbol);
         }
+
+        for (const auto& symbol : punctuation_symbols) {
+            tokens.push_back(symbol);
+        }
     }
 
     std::vector<int64_t> SequenceTokenizer::operator()(const std::string& sentence, const std::string& language) const {
@@ -189,41 +202,45 @@
     }
 
     std::vector<std::string> Session::g2p(const std::string& text) {
+        // Convert input text to phonemes
+        std::vector<int64_t> phoneme_tokens = g2p_tokens(text);
+
+        // Decode the phoneme tokens
+        return phoneme_tokenizer->decode(phoneme_tokens, true);
+    }
+
+    std::vector<int64_t> Session::g2p_tokens(const std::string& text) {
         // Clean the input text
         std::vector<std::string> words = clean_text(text);
 
         // Convert each word to phonemes
-        std::vector<std::string> phonemes;
+        std::vector<int64_t> phoneme_ids;
         for (const auto& word : words) {
-            std::vector<std::string> word_phonemes = g2p_internal(word);
+            std::vector<int64_t> word_phoneme_ids = g2p_tokens_internal(word);
 
-            phonemes.insert(phonemes.end(), word_phonemes.begin(), word_phonemes.end());
+            phoneme_ids.insert(phoneme_ids.end(), word_phoneme_ids.begin(), word_phoneme_ids.end());
 
             if (punctuation) {
                 // Check if the word ends with punctuation
                 if (std::ispunct(word.back())) {
-                    phonemes.push_back(std::string(1, word.back()));
+                    auto punct_token = phoneme_tokenizer->operator()(std::string(1, word.back()), lang);
+                    phoneme_ids.insert(phoneme_ids.end(), punct_token.begin(), punct_token.end());
                 }
             }
 
-            phonemes.push_back(" ");
+            phoneme_ids.push_back(0);
         }
 
-        return phonemes;
+        return phoneme_ids;
     }
 
-    std::vector<std::string> Session::g2p_internal(const std::string& text) {
+    std::vector<int64_t> Session::g2p_tokens_internal(const std::string& text) {
         // Check if the input text is longer than one character
         std::string key_text = text;
         std::transform(key_text.begin(), key_text.end(), key_text.begin(), ::tolower);
         key_text.erase(std::remove_if(key_text.begin(), key_text.end(), ::ispunct), key_text.end());
 
-        // First check if word is in the dictionary
-        if (dictionary.count(key_text)) {
-            return dictionary.at(key_text);
-        }
-
         // Convert input text to tensor
         std::vector<Ort::Value> input_tensors;
         std::vector<int64_t> input_ids = text_tokenizer->operator()(text, lang);
@@ -274,9 +291,6 @@ namespace DeepPhonemizer {
            output_ids_vector[i] = std::distance(probabilities.begin(), max_prob_iter);
         }
 
-        // Convert output IDs to phonemes
-        std::vector<std::string> phonemes = phoneme_tokenizer->decode(output_ids_vector, true);
-
-        return phonemes;
+        return output_ids_vector;
     }
 }
\ No newline at end of file
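
Usage note (not part of the patch): the change splits Session::g2p into a lower-level
g2p_tokens, which returns raw phoneme token ids (punctuation encoded through the phoneme
tokenizer, 0 pushed after each word as a separator), while g2p becomes a thin wrapper that
decodes those ids. A minimal caller-side sketch, assuming the signatures reconstructed above
(std::vector<int64_t> g2p_tokens, std::vector<std::string> g2p); the print_phonemes helper is
hypothetical, and constructing a Session is outside the scope of this patch:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    #include "babylon.hpp"

    // Hypothetical helper: shows both views of the same input text.
    void print_phonemes(DeepPhonemizer::Session& session, const std::string& text) {
        // New lower-level API: raw phoneme token ids.
        std::vector<int64_t> ids = session.g2p_tokens(text);

        // High-level API: g2p now decodes the same ids into phoneme strings.
        std::vector<std::string> phonemes = session.g2p(text);

        std::cout << text << " -> " << ids.size() << " token ids, "
                  << phonemes.size() << " phoneme strings\n";
    }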