From 23935d02f8c27a28446d5c6d99a5669a85ddf6a0 Mon Sep 17 00:00:00 2001 From: danemadsen Date: Tue, 13 Aug 2024 16:28:44 +1000 Subject: [PATCH] simplify --- include/babylon.hpp | 3 --- src/phonemizer.cpp | 15 +++------------ 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/include/babylon.hpp b/include/babylon.hpp index aa432fd..0a42de8 100644 --- a/include/babylon.hpp +++ b/include/babylon.hpp @@ -25,9 +25,6 @@ namespace DeepPhonemizer { std::string pad_token; std::string end_token; std::unordered_set special_tokens; - - int get_start_index(const std::string& language) const; - std::string make_start_token(const std::string& language) const; }; class Session { diff --git a/src/phonemizer.cpp b/src/phonemizer.cpp index 011216f..26cabf5 100644 --- a/src/phonemizer.cpp +++ b/src/phonemizer.cpp @@ -20,7 +20,7 @@ std::vector softmax(const std::vector& logits) { for (size_t i = 0; i < logits.size(); ++i) { probabilities[i] = std::exp(logits[i] - max_logit) / sum; } - + return probabilities; } @@ -33,7 +33,7 @@ namespace DeepPhonemizer { special_tokens.insert(pad_token); for (const auto& lang : languages) { - std::string lang_token = make_start_token(lang); + std::string lang_token = "<" + lang + ">"; token_to_idx[lang_token] = token_to_idx.size(); special_tokens.insert(lang_token); } @@ -68,7 +68,7 @@ namespace DeepPhonemizer { } if (append_start_end) { - sequence.insert(sequence.begin(), get_start_index(language)); + sequence.insert(sequence.begin(), token_to_idx.at("<" + language + ">")); sequence.push_back(end_index); } @@ -123,15 +123,6 @@ namespace DeepPhonemizer { return decoded; } - int SequenceTokenizer::get_start_index(const std::string& language) const { - std::string lang_token = make_start_token(language); - return token_to_idx.at(lang_token); - } - - std::string SequenceTokenizer::make_start_token(const std::string& language) const { - return "<" + language + ">"; - } - Session::Session(const std::string& model_path, const std::string language, const bool use_punctuation) { Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "DeepPhonemizer"); env.DisableTelemetryEvents();