From 3d654d2b307e895f56704050a566c0231e761c7c Mon Sep 17 00:00:00 2001 From: danemadsen Date: Tue, 13 Aug 2024 16:39:58 +1000 Subject: [PATCH] simplify tokens --- include/babylon.hpp | 3 +-- src/phonemizer.cpp | 36 ++++++++++++++++++++---------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/include/babylon.hpp b/include/babylon.hpp index 0a42de8..7f0d364 100644 --- a/include/babylon.hpp +++ b/include/babylon.hpp @@ -15,8 +15,7 @@ namespace DeepPhonemizer { std::vector decode(const std::vector& sequence, bool remove_special_tokens = false) const; private: - std::unordered_map token_to_idx; - std::unordered_map idx_to_token; + std::vector tokens; int char_repeats; bool lowercase; bool append_start_end; diff --git a/src/phonemizer.cpp b/src/phonemizer.cpp index 26cabf5..a3e81d1 100644 --- a/src/phonemizer.cpp +++ b/src/phonemizer.cpp @@ -24,29 +24,32 @@ std::vector softmax(const std::vector& logits) { return probabilities; } +int get_token_index(const std::vector& tokens, const std::string& token) { + auto it = std::find(tokens.begin(), tokens.end(), token); + if (it != tokens.end()) { + return std::distance(tokens.begin(), it); + } + return -1; +} + namespace DeepPhonemizer { SequenceTokenizer::SequenceTokenizer(const std::vector& symbols, const std::vector& languages, int char_repeats, bool lowercase, bool append_start_end) : char_repeats(char_repeats), lowercase(lowercase), append_start_end(append_start_end), pad_token("_"), end_token("") { - pad_index = 0; - token_to_idx[pad_token] = pad_index; + tokens.push_back(pad_token); special_tokens.insert(pad_token); for (const auto& lang : languages) { std::string lang_token = "<" + lang + ">"; - token_to_idx[lang_token] = token_to_idx.size(); + tokens.push_back(lang_token); special_tokens.insert(lang_token); } - token_to_idx[end_token] = token_to_idx.size(); - end_index = token_to_idx[end_token]; + tokens.push_back(end_token); + end_index = tokens.size() - 1; for (const auto& symbol : symbols) { - token_to_idx[symbol] = token_to_idx.size(); - } - - for (const auto& pair : token_to_idx) { - idx_to_token[pair.second] = pair.first; + tokens.push_back(symbol); } } @@ -59,16 +62,17 @@ namespace DeepPhonemizer { std::vector sequence; for (char c : processed_sentence) { std::string symbol(1, c); - auto it = token_to_idx.find(symbol); - if (it != token_to_idx.end()) { + auto index = get_token_index(tokens, symbol); + if (index != -1) { for (int i = 0; i < char_repeats; ++i) { - sequence.push_back(it->second); + sequence.push_back(index); } } } if (append_start_end) { - sequence.insert(sequence.begin(), token_to_idx.at("<" + language + ">")); + auto index = get_token_index(tokens, "<" + language + ">"); + sequence.insert(sequence.begin(), index); sequence.push_back(end_index); } @@ -114,10 +118,10 @@ namespace DeepPhonemizer { if (token == end_index) { break; } - if (remove_special_tokens && special_tokens.count(idx_to_token.at(token))) { + if (remove_special_tokens && special_tokens.count(tokens[token])) { continue; } - decoded.push_back(idx_to_token.at(token)); + decoded.push_back(tokens[token]); } return decoded;