Skip to content

Commit

Permalink
tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
danemadsen committed Aug 14, 2024
1 parent f026451 commit 4114efa
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 51 deletions.
4 changes: 3 additions & 1 deletion include/babylon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ namespace DeepPhonemizer {
public:
SequenceTokenizer(const std::vector<std::string>& symbols, const std::vector<std::string>& languages, int char_repeats, bool lowercase = true, bool append_start_end = true);
std::vector<int64_t> operator()(const std::string& sentence, const std::string& language) const;
std::vector<std::string> decode(const std::vector<int64_t>& sequence, bool remove_special_tokens = false) const;
std::vector<std::string> decode(const std::vector<int64_t>& sequence) const;
std::vector<int64_t> clean(const std::vector<int64_t>& sequence) const;
int64_t get_token(const std::string& token) const;

private:
std::vector<std::string> tokens;
Expand Down
Binary file modified models/deep_phonemizer.onnx
Binary file not shown.
2 changes: 1 addition & 1 deletion scripts/deep_phonemizer/dp_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def forward(self, text, phonemes=None, start_index=None):
metadata = {
"languages": "de en_us",
"text_symbols": "a b c d e f g h i j k l m n o p q r s t u v w x y z A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ä ö ü Ä Ö Ü ß",
"phoneme_symbols": "a b d e f g h i j k l m n o p r s t u v w x y z æ ç ð ø ŋ œ ɐ ɑ ɔ ə ɛ ɜ ɹ ɡ ɪ ʁ ʃ ʊ ʌ ʏ ʒ ʔ ' ˌ ː ̃ ̍ ̥ ̩ ̯ ͡ θ",
"phoneme_symbols": "a b d e f g h i j k l m n o p r s t u v w x y z æ ç ð ø ŋ œ ɐ ɑ ɔ ə ɛ ɜ ɹ ɡ ɪ ʁ ʃ ʊ ʌ ʏ ʒ ʔ ' ˌ ː ̃ ̍ ̥ ̩ ̯ ͡ θ . , : ; ? !",
"char_repeats": "3" if isinstance(model, ForwardTransformer) else "1",
"lowercase": "1"
}
Expand Down
101 changes: 52 additions & 49 deletions src/phonemizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@
// Single-entry name tables holding the C strings "text" and "output".
// NOTE(review): presumably the ONNX graph's input/output tensor names that are
// handed to the inference call — confirm against the Session implementation.
const std::array<const char *, 1> input_names = {"text"};
const std::array<const char *, 1> output_names = {"output"};

// Punctuation marks that the SequenceTokenizer constructor appends to its
// token list after the caller-supplied symbols.
const std::vector<std::string> punctuation_symbols = {
".",
",",
":",
";",
"!",
"?"
};

std::vector<float> softmax(const std::vector<float>& logits) {
float max_logit = *std::max_element(logits.begin(), logits.end());
std::vector<float> probabilities(logits.size());
Expand All @@ -33,17 +24,9 @@ std::vector<float> softmax(const std::vector<float>& logits) {
return probabilities;
}

// Linear search for `token` inside `tokens`.
// Yields the zero-based position of the first exact match, or -1 when no
// entry matches.
int get_token_index(const std::vector<std::string>& tokens, const std::string& token) {
    const auto first = tokens.begin();
    const auto match = std::find(first, tokens.end(), token);
    return match == tokens.end() ? -1 : static_cast<int>(std::distance(first, match));
}

namespace DeepPhonemizer {
SequenceTokenizer::SequenceTokenizer(const std::vector<std::string>& symbols, const std::vector<std::string>& languages, int char_repeats, bool lowercase, bool append_start_end)
: char_repeats(char_repeats), lowercase(lowercase), append_start_end(append_start_end), pad_token("_"), end_token("<end>") {
: char_repeats(char_repeats), lowercase(lowercase), append_start_end(append_start_end), pad_token(" "), end_token("<end>") {

tokens.push_back(pad_token);
special_tokens.insert(pad_token);
Expand All @@ -60,10 +43,6 @@ namespace DeepPhonemizer {
for (const auto& symbol : symbols) {
tokens.push_back(symbol);
}

for (const auto& symbol : punctuation_symbols) {
tokens.push_back(symbol);
}
}

std::vector<int64_t> SequenceTokenizer::operator()(const std::string& sentence, const std::string& language) const {
Expand All @@ -75,7 +54,7 @@ namespace DeepPhonemizer {
std::vector<int64_t> sequence;
for (char c : processed_sentence) {
std::string symbol(1, c);
auto index = get_token_index(tokens, symbol);
auto index = get_token(symbol);
if (index != -1) {
for (int i = 0; i < char_repeats; ++i) {
sequence.push_back(index);
Expand All @@ -84,7 +63,7 @@ namespace DeepPhonemizer {
}

if (append_start_end) {
auto index = get_token_index(tokens, "<" + language + ">");
auto index = get_token("<" + language + ">");
sequence.insert(sequence.begin(), index);
sequence.push_back(end_index);
}
Expand All @@ -102,44 +81,65 @@ namespace DeepPhonemizer {
return sequence;
}

std::vector<std::string> SequenceTokenizer::decode(const std::vector<int64_t>& sequence, bool remove_special_tokens) const {
std::vector<int64_t> pruned_sequence = sequence;
pruned_sequence.erase(
std::remove(pruned_sequence.begin(), pruned_sequence.end(), pad_index),
pruned_sequence.end()
);

std::vector<std::string> SequenceTokenizer::decode(const std::vector<int64_t>& sequence) const {
std::vector<int64_t> processed_sequence;
if (append_start_end) {
processed_sequence.push_back(pruned_sequence.front());
for (size_t i = 1; i < pruned_sequence.size() - 1; i += char_repeats) {
processed_sequence.push_back(pruned_sequence[i]);
processed_sequence.push_back(sequence.front());
for (size_t i = 1; i < sequence.size() - 1; i += char_repeats) {
processed_sequence.push_back(sequence[i]);
}
processed_sequence.push_back(pruned_sequence.back());
processed_sequence.push_back(sequence.back());
} else {
for (size_t i = 0; i < pruned_sequence.size(); i += char_repeats) {
processed_sequence.push_back(pruned_sequence[i]);
for (size_t i = 0; i < sequence.size(); i += char_repeats) {
processed_sequence.push_back(sequence[i]);
}
}

// Remove consecutive duplicate tokens
auto last = std::unique(processed_sequence.begin(), processed_sequence.end());
processed_sequence.erase(last, processed_sequence.end());

std::vector<std::string> decoded;
for (int64_t token : processed_sequence) {
if (token == end_index) {
break;
}
if (remove_special_tokens && special_tokens.count(tokens[token])) {
continue;
}
decoded.push_back(tokens[token]);
}

return decoded;
}

/**
 * Reduce a raw token-id sequence to its payload ids.
 *
 * Steps, in this order:
 *   1. Cut the sequence at the first occurrence of `end_index`; the end
 *      marker and everything after it are discarded.
 *   2. Drop every id that corresponds to an entry of `special_tokens`.
 *   3. Collapse each run of consecutive identical ids into a single id.
 *
 * Truncation deliberately happens BEFORE special-token removal. If the end
 * token is itself listed in `special_tokens`, stripping first would erase the
 * marker, the cut would never happen, and ids located after the end marker
 * would leak into the result. When the end token is not a special token both
 * orders give the same output.
 *
 * @param sequence Token ids to clean; not modified.
 * @return A cleaned copy of `sequence`.
 */
std::vector<int64_t> SequenceTokenizer::clean(const std::vector<int64_t>& sequence) const {
    // Keep only what precedes the first end marker (the whole input when
    // no end marker is present).
    const auto end = std::find(sequence.begin(), sequence.end(), end_index);
    std::vector<int64_t> processed_sequence(sequence.begin(), end);

    // Remove all special tokens from the remaining ids. Iterate by const
    // reference so the token strings are not copied on every pass.
    for (const auto& token : special_tokens) {
        const auto special_token_index = get_token(token);
        if (special_token_index != -1) {
            processed_sequence.erase(
                std::remove(processed_sequence.begin(), processed_sequence.end(), special_token_index),
                processed_sequence.end());
        }
    }

    // Remove consecutive duplicate tokens.
    processed_sequence.erase(
        std::unique(processed_sequence.begin(), processed_sequence.end()),
        processed_sequence.end());

    return processed_sequence;
}

// Look up `token` in the tokenizer's `tokens` list.
// Returns the zero-based position of the first exact match, or -1 when the
// token is not present.
int64_t SequenceTokenizer::get_token(const std::string& token) const {
    const auto first = tokens.begin();
    const auto match = std::find(first, tokens.end(), token);

    if (match == tokens.end()) {
        return -1;
    }

    return std::distance(first, match);
}

Session::Session(const std::string& model_path, const std::string language, const bool use_punctuation) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "DeepPhonemizer");
env.DisableTelemetryEvents();
Expand Down Expand Up @@ -206,7 +206,7 @@ namespace DeepPhonemizer {
std::vector<int64_t> phoneme_tokens = g2p_tokens(text);

// Decode the phoneme tokens
return phoneme_tokenizer->decode(phoneme_tokens, true);;
return phoneme_tokenizer->decode(phoneme_tokens);
}

std::vector<int64_t> Session::g2p_tokens(const std::string& text) {
Expand All @@ -217,14 +217,17 @@ namespace DeepPhonemizer {
std::vector<int64_t> phoneme_ids;
for (const auto& word : words) {
std::vector<int64_t> word_phoneme_ids = g2p_tokens_internal(word);

std::vector<int64_t> cleaned_word_phoneme_ids = phoneme_tokenizer->clean(word_phoneme_ids);

phoneme_ids.insert(phoneme_ids.end(), word_phoneme_ids.begin(), word_phoneme_ids.end());
phoneme_ids.insert(phoneme_ids.end(), cleaned_word_phoneme_ids.begin(), cleaned_word_phoneme_ids.end());

if (punctuation) {
auto back_token = phoneme_tokenizer->get_token(std::string(1, word.back()));

// Check if the word ends with punctuation
if (std::ispunct(word.back())) {
auto punct_token = phoneme_tokenizer->operator()(std::string(1, word.back()), lang);
phoneme_ids.insert(phoneme_ids.end(), punct_token.begin(), punct_token.end());
if (std::ispunct(word.back()) && back_token != -1) {
phoneme_ids.push_back(back_token);
}
}

Expand Down

0 comments on commit 4114efa

Please sign in to comment.