Skip to content

Commit

Permalink
option to not use dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
danemadsen committed Aug 13, 2024
1 parent 7bdc946 commit ba61245
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 25 deletions.
10 changes: 9 additions & 1 deletion include/babylon.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@
extern "C" {
#endif

#include <stdbool.h>

#ifdef WIN32
#define BABYLON_EXPORT __declspec(dllexport)
#else
#define BABYLON_EXPORT __attribute__((visibility("default"))) __attribute__((used))
#endif

BABYLON_EXPORT int babylon_g2p_init(const char* model_path, const char* language, int use_punctuation);
typedef struct {
const char* language;
bool use_punctuation;
bool use_dictionary;
} babylon_g2p_options;

BABYLON_EXPORT int babylon_g2p_init(const char* model_path, babylon_g2p_options* options);

BABYLON_EXPORT char* babylon_g2p(const char* text);

Expand Down
7 changes: 6 additions & 1 deletion include/babylon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ namespace DeepPhonemizer {

class Session {
public:
Session(const std::string& model_path, const std::string language = "en_us", const bool use_punctuation = false);
Session(
const std::string& model_path,
const std::string language = "en_us",
const bool use_punctuation = false,
const bool use_dictionary = false
);
~Session();

std::vector<std::string> g2p(const std::string& text);
Expand Down
4 changes: 2 additions & 2 deletions src/babylon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ static DeepPhonemizer::Session* dp;
static Vits::Session* vits;

extern "C" {
BABYLON_EXPORT int babylon_g2p_init(const char* model_path, const char* language, int use_punctuation) {
BABYLON_EXPORT int babylon_g2p_init(const char* model_path, babylon_g2p_options* options) {
try {
dp = new DeepPhonemizer::Session(model_path, language, use_punctuation);
dp = new DeepPhonemizer::Session(model_path, options->language, options->use_punctuation);
return 0;
}
catch (const std::exception& e) {
Expand Down
4 changes: 0 additions & 4 deletions src/cleaners.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,6 @@ namespace DeepPhonemizer {
}
}

if (!word.empty()) {
words.push_back(word);
}

return words;
}
}
43 changes: 26 additions & 17 deletions src/phonemizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,29 @@ namespace DeepPhonemizer {
return probabilities;
}

Session::Session(const std::string& model_path, const std::string language, const bool use_punctuation) {
std::unordered_map<std::string, std::vector<std::string>> parse_dictionary(const std::string& dictionary_str) {
std::unordered_map<std::string, std::vector<std::string>> dictionary;

std::istringstream dictionary_stream(dictionary_str);
std::string line;
while (std::getline(dictionary_stream, line)) {
std::stringstream line_stream(line);
std::string word;
line_stream >> word;

std::vector<std::string> phonemes;
std::string phoneme;
while (line_stream >> phoneme) {
phonemes.push_back(phoneme);
}

dictionary[word] = phonemes;
}

return dictionary;
}

Session::Session(const std::string& model_path, const std::string language, const bool use_punctuation, const bool use_dictionary) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "DeepPhonemizer");
env.DisableTelemetryEvents();

Expand Down Expand Up @@ -170,22 +192,9 @@ namespace DeepPhonemizer {
phoneme_symbols.push_back(phoneme_symbol_buffer);
}

std::string dictonary_str = model_metadata.LookupCustomMetadataMapAllocated("dictionary", allocator).get();

std::istringstream dictionary_stream(dictonary_str);
std::string line;
while (std::getline(dictionary_stream, line)) {
std::stringstream line_stream(line);
std::string word;
line_stream >> word;

std::vector<std::string> phonemes;
std::string phoneme;
while (line_stream >> phoneme) {
phonemes.push_back(phoneme);
}

dictionary[word] = phonemes;
if (use_dictionary) {
std::string dictonary_str = model_metadata.LookupCustomMetadataMapAllocated("dictionary", allocator).get();
dictionary = parse_dictionary(dictonary_str);
}

int char_repeats = model_metadata.LookupCustomMetadataMapAllocated("char_repeats", allocator).get()[0] - '0';
Expand Down

0 comments on commit ba61245

Please sign in to comment.