Skip to content

Commit

Permalink
work on punct tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
danemadsen committed Aug 13, 2024
1 parent 3d654d2 commit f026451
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
4 changes: 2 additions & 2 deletions include/babylon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ namespace DeepPhonemizer {
~Session();

std::vector<std::string> g2p(const std::string& text);
std::vector<int64_t> g2p_tokens(const std::string& text);

private:
std::string lang;
bool punctuation;
Ort::Session* session;
SequenceTokenizer* text_tokenizer;
SequenceTokenizer* phoneme_tokenizer;
std::unordered_map<std::string, std::vector<std::string>> dictionary;

std::vector<std::string> g2p_internal(const std::string& text);
std::vector<int64_t> g2p_tokens_internal(const std::string& text);
};

std::vector<std::string> clean_text(const std::string& text);
Expand Down
46 changes: 30 additions & 16 deletions src/phonemizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
const std::array<const char *, 1> input_names = {"text"};
const std::array<const char *, 1> output_names = {"output"};

// Punctuation marks recognised by the phonemizer. Elsewhere in this file they
// are pushed onto `tokens` right after the regular `symbols`.
const std::vector<std::string> punctuation_symbols{".", ",", ":", ";", "!", "?"};

std::vector<float> softmax(const std::vector<float>& logits) {
float max_logit = *std::max_element(logits.begin(), logits.end());
std::vector<float> probabilities(logits.size());
Expand Down Expand Up @@ -51,6 +60,10 @@ namespace DeepPhonemizer {
for (const auto& symbol : symbols) {
tokens.push_back(symbol);
}

for (const auto& symbol : punctuation_symbols) {
tokens.push_back(symbol);
}
}

std::vector<int64_t> SequenceTokenizer::operator()(const std::string& sentence, const std::string& language) const {
Expand Down Expand Up @@ -189,41 +202,45 @@ namespace DeepPhonemizer {
}

std::vector<std::string> Session::g2p(const std::string& text) {
// Convert input text to phonemes
std::vector<int64_t> phoneme_tokens = g2p_tokens(text);

// Decode the phoneme tokens
return phoneme_tokenizer->decode(phoneme_tokens, true);;
}

/// Convert free-form text into one flat sequence of phoneme token ids.
///
/// The text is split into words by clean_text(). Each word is converted with
/// g2p_tokens_internal() and its ids are appended to the result. When the
/// `punctuation` member is set and the word ends in a punctuation character,
/// that character is tokenized with the phoneme tokenizer and its ids are
/// appended too. A 0 id is appended after every word.
///
/// @param text Input text.
/// @return Token ids for all words, in input order.
std::vector<int64_t> Session::g2p_tokens(const std::string& text) {
    // Split the input into words.
    const std::vector<std::string> words = clean_text(text);

    std::vector<int64_t> phoneme_ids;
    for (const auto& word : words) {
        const std::vector<int64_t> word_phoneme_ids = g2p_tokens_internal(word);
        phoneme_ids.insert(phoneme_ids.end(), word_phoneme_ids.begin(), word_phoneme_ids.end());

        // back() on an empty string is undefined, so only inspect the last
        // character when there is one.
        if (punctuation && !word.empty()) {
            const char last = word.back();
            // <cctype> classifiers require a value representable as
            // unsigned char; a negative plain char would be undefined.
            if (std::ispunct(static_cast<unsigned char>(last))) {
                // NOTE(review): operator() may emit more than the single
                // punctuation id (e.g. language/end markers) — confirm
                // against SequenceTokenizer::operator().
                const auto punct_token = (*phoneme_tokenizer)(std::string(1, last), lang);
                phoneme_ids.insert(phoneme_ids.end(), punct_token.begin(), punct_token.end());
            }
        }

        // Word separator. Presumably id 0 is the pad/blank token that stands
        // in for the former " " phoneme — TODO confirm.
        phoneme_ids.push_back(0);
    }

    return phoneme_ids;
}

std::vector<std::string> Session::g2p_internal(const std::string& text) {
std::vector<int64_t> Session::g2p_tokens_internal(const std::string& text) {
// Check if the input text is longer than one character
std::string key_text = text;
std::transform(key_text.begin(), key_text.end(), key_text.begin(), ::tolower);

key_text.erase(std::remove_if(key_text.begin(), key_text.end(), ::ispunct), key_text.end());

// First check if word is in the dictionary
if (dictionary.count(key_text)) {
return dictionary.at(key_text);
}

// Convert input text to tensor
std::vector<Ort::Value> input_tensors;
std::vector<int64_t> input_ids = text_tokenizer->operator()(text, lang);
Expand Down Expand Up @@ -274,9 +291,6 @@ namespace DeepPhonemizer {
output_ids_vector[i] = std::distance(probabilities.begin(), max_prob_iter);
}

// Convert output IDs to phonemes
std::vector<std::string> phonemes = phoneme_tokenizer->decode(output_ids_vector, true);

return phonemes;
return output_ids_vector;
}
}

0 comments on commit f026451

Please sign in to comment.