diff --git a/PlatformIO/lib/esp_sr/libc_speech_features.a b/PlatformIO/lib/esp_sr/libc_speech_features.a new file mode 100644 index 0000000..4e665a5 Binary files /dev/null and b/PlatformIO/lib/esp_sr/libc_speech_features.a differ diff --git a/PlatformIO/lib/esp_sr/libdl_lib.a b/PlatformIO/lib/esp_sr/libdl_lib.a new file mode 100644 index 0000000..65614cd Binary files /dev/null and b/PlatformIO/lib/esp_sr/libdl_lib.a differ diff --git a/PlatformIO/lib/esp_sr/libnn_model_alexa_wn3.a b/PlatformIO/lib/esp_sr/libnn_model_alexa_wn3.a new file mode 100644 index 0000000..1889573 Binary files /dev/null and b/PlatformIO/lib/esp_sr/libnn_model_alexa_wn3.a differ diff --git a/PlatformIO/lib/esp_sr/libwakenet.a b/PlatformIO/lib/esp_sr/libwakenet.a new file mode 100644 index 0000000..4c3acd1 Binary files /dev/null and b/PlatformIO/lib/esp_sr/libwakenet.a differ diff --git a/PlatformIO/platformio.ini b/PlatformIO/platformio.ini index b6990b4..331f213 100644 --- a/PlatformIO/platformio.ini +++ b/PlatformIO/platformio.ini @@ -8,15 +8,7 @@ ; Please visit documentation for the other options and examples ; https://docs.platformio.org/page/projectconf.html -[env:esp32dev] -platform = espressif32@1.9.0 -upload_protocol = espota -board = esp32dev -framework = arduino -board_build.partitions = ../OTABuilder/partitions_two_ota.csv -; MatrixVoice ESP32 LAN name or IP, should match HOSTNAME in build_flags -upload_port = '192.168.43.140' - +[common] build_flags = '-DFIXED_POINT=1' '-DOUTSIDE_SPEEX=1' @@ -31,6 +23,21 @@ build_flags = '-DMQTT_USER="username"' ; Change to your MQTT username '-DMQTT_PASS="password"' ; Change to your MQTT password '-DMQTT_MAX_PACKET_SIZE=2000' ; This is required, otherwise audiopackets will not be send + '-lnn_model_alexa_wn3' + '-Llib/esp_sr' + '-lwakenet' + '-ldl_lib' + '-lc_speech_features' + +[env:esp32dev] +platform = espressif32@1.9.0 +upload_protocol = espota +board = esp32dev +framework = arduino +board_build.partitions = 
../OTABuilder/partitions_two_ota.csv +; MatrixVoice ESP32 LAN name or IP, should match HOSTNAME in build_flags +upload_port = '192.168.43.140' +build_flags = ${common.build_flags} ; MatrixVoice OTA password (auth), should match hashed password (OTA_PASS_HASH) in build_flags upload_flags = diff --git a/PlatformIO/src/MatrixVoiceAudioServer.cpp b/PlatformIO/src/MatrixVoiceAudioServer.cpp index 2039dee..ec35e15 100644 --- a/PlatformIO/src/MatrixVoiceAudioServer.cpp +++ b/PlatformIO/src/MatrixVoiceAudioServer.cpp @@ -85,8 +85,14 @@ extern "C" { #include "freertos/event_groups.h" #include "freertos/timers.h" #include "speex_resampler.h" + #include "esp_wn_iface.h" } +extern const esp_wn_iface_t esp_sr_wakenet3_quantized; +extern const model_coeff_getter_t get_coeff_wakeNet3_model_float; +#define WAKENET_COEFF get_coeff_wakeNet3_model_float +#define WAKENET_MODEL esp_sr_wakenet3_quantized + /* ************************************************************************* * DEFINES AND GLOBALS * ************************************************************************ */ @@ -96,6 +102,9 @@ extern "C" { #define DATA_CHUNK_ID 0x61746164 #define FMT_CHUNK_ID 0x20746d66 +static const esp_wn_iface_t *wakenet = &WAKENET_MODEL; +static const model_coeff_getter_t *model_coeff_getter = &WAKENET_COEFF; + // These parameters enable you to select the default value for output enum { AMP_OUT_SPEAKERS = 0, @@ -169,8 +178,10 @@ bool hotword_detected = false; bool isUpdateInProgess = false; bool streamingBytes = false; bool endStream = false; +bool localHotwordDetection = false; bool DEBUG = false; std::string finishedMsg = ""; +std::string detectMsg = ""; int message_count; int CHUNK = 256; // set to multiplications of 256, voice return a set of 256 int chunkValues[] = {32, 64, 128, 256, 512, 1024}; @@ -213,6 +224,7 @@ std::string playBytesStreamingTopic = std::string("hermes/audioServer/") + SITEI std::string rhasspyWakeTopic = std::string("rhasspy/+/transition/+"); std::string toggleOffTopic 
= "hermes/hotword/toggleOff"; std::string toggleOnTopic = "hermes/hotword/toggleOn"; +std::string hotwordDetectedTopic = "hermes/hotword/default/detected"; std::string everloopTopic = SITEID + std::string("/everloop"); std::string debugTopic = SITEID + std::string("/debug"); std::string audioTopic = SITEID + std::string("/audio"); @@ -507,6 +519,9 @@ void onMqttMessage(char *topic, char *payload, AsyncMqttClientMessageProperties if (root.containsKey("gain")) { mics.SetGain((int)root["gain"]); } + if (root.containsKey("hotword")) { + localHotwordDetection = (root["hotword"] == "local") ? true : false; + } } else { publishDebug(err.c_str()); } @@ -573,6 +588,7 @@ void onMqttMessage(char *topic, char *payload, AsyncMqttClientMessageProperties AUDIOSTREAM TASK, USES SYNCED MQTT CLIENT * ************************************************************************ */ void Audiostream(void *p) { + model_iface_data_t *model_data = wakenet->create(model_coeff_getter, DET_MODE_90); while (1) { // Wait for the bit before updating. Do not clear in the wait exit; (first false) xEventGroupWaitBits(audioGroup, STREAM, false, false, portMAX_DELAY); @@ -589,24 +605,46 @@ void Audiostream(void *p) { uint8_t voicemapped[CHUNK * WIDTH]; uint8_t payload[sizeof(header) + (CHUNK * WIDTH)]; - // Message count is the Matrix NumberOfSamples divided by the - // framerate of Snips. This defaults to 512 / 256 = 2. 
If you - // lower the framerate, the AudioServer has to send more - // wavefile because the NumOfSamples is a fixed number - for (int i = 0; i < message_count; i++) { - for (uint32_t s = CHUNK * i; s < CHUNK * (i + 1); s++) { - voicebuffer[s - (CHUNK * i)] = mics.Beam(s); + if (!hotword_detected && localHotwordDetection) { + + int16_t voicebuffer_wk[CHUNK * WIDTH]; + for (uint32_t s = 0; s < CHUNK * WIDTH; s++) { + voicebuffer_wk[s] = mics.Beam(s); + } + + int r = wakenet->detect(model_data, voicebuffer_wk); + if (r > 0) { + detectMsg = std::string("{\"siteId\":\"") + SITEID + std::string("\"}"); + asyncClient.publish(hotwordDetectedTopic.c_str(), 0, false, detectMsg.c_str()); + hotword_detected = true; + publishDebug("Hotword Detected"); + } + //simulate message for leds + for (int i = 0; i < message_count; i++) { + streamMessageCount++; + } + } + + if (hotword_detected || !localHotwordDetection) { + // Message count is the Matrix NumberOfSamples divided by the + // framerate of Snips. This defaults to 512 / 256 = 2. 
If you + // lower the framerate, the AudioServer has to send more + // wavefile because the NumOfSamples is a fixed number + for (int i = 0; i < message_count; i++) { + for (uint32_t s = CHUNK * i; s < CHUNK * (i + 1); s++) { + voicebuffer[s - (CHUNK * i)] = mics.Beam(s); + } + // voicebuffer will hold 256 samples of 2 bytes, but we need + // it as 1 byte We do a memcpy, because I need to add the + // wave header as well + memcpy(voicemapped, voicebuffer, CHUNK * WIDTH); + + // Add the wave header + memcpy(payload, &header, sizeof(header)); + memcpy(&payload[sizeof(header)], voicemapped,sizeof(voicemapped)); + audioServer.publish(audioFrameTopic.c_str(),(uint8_t *)payload, sizeof(payload)); + streamMessageCount++; } - // voicebuffer will hold 256 samples of 2 bytes, but we need - // it as 1 byte We do a memcpy, because I need to add the - // wave header as well - memcpy(voicemapped, voicebuffer, CHUNK * WIDTH); - - // Add the wave header - memcpy(payload, &header, sizeof(header)); - memcpy(&payload[sizeof(header)], voicemapped,sizeof(voicemapped)); - audioServer.publish(audioFrameTopic.c_str(),(uint8_t *)payload, sizeof(payload)); - streamMessageCount++; } } xSemaphoreGive(wbSemaphore); // Now free or "Give" the Serial Port for others. 
diff --git a/PlatformIO/src/esp_wn_iface.h b/PlatformIO/src/esp_wn_iface.h new file mode 100644 index 0000000..8082681 --- /dev/null +++ b/PlatformIO/src/esp_wn_iface.h @@ -0,0 +1,122 @@ +#pragma once +#include "stdint.h" +#include "dl_lib_coefgetter_if.h" + +//Opaque model data container +typedef struct model_iface_data_t model_iface_data_t; + +//Set wake words recognition operating mode +//The probability of being wake words is increased with increasing mode, +//As a consequence also the false alarm rate goes up +typedef enum { + DET_MODE_90 = 0, //Normal, response accuracy rate about 90% + DET_MODE_95 //Aggressive, response accuracy rate about 95% +} det_mode_t; + +typedef struct { + int wake_word_num; //The number of all wake words + char **wake_word_list; //The name list of wake words +} wake_word_info_t; + +/** + * @brief Easy function type to initialize a model instance with a detection mode and specified wake word coefficient + * + * @param model_coeff The specified wake word model coefficient + * @param det_mode The wake words detection mode to trigger wake words, DET_MODE_90 or DET_MODE_95 + * @returns Handle to the model data + */ +typedef model_iface_data_t* (*esp_wn_iface_op_create_t)(const model_coeff_getter_t *model_coeff, det_mode_t det_mode); + + +/** + * @brief Callback function type to fetch the amount of samples that need to be passed to the detect function + * + * Every speech recognition model processes a certain number of samples at the same time. This function + * can be used to query that amount. Note that the returned amount is in 16-bit samples, not in bytes.
+ * + * @param model The model object to query + * @return The amount of samples to feed the detect function + */ +typedef int (*esp_wn_iface_op_get_samp_chunksize_t)(model_iface_data_t *model); + + +/** + * @brief Get the sample rate of the samples to feed to the detect function + * + * @param model The model object to query + * @return The sample rate, in Hz + */ +typedef int (*esp_wn_iface_op_get_samp_rate_t)(model_iface_data_t *model); + +/** + * @brief Get the number of wake words + * + * @param model The model object to query + * @returns the number of wake words + */ +typedef int (*esp_wn_iface_op_get_word_num_t)(model_iface_data_t *model); + +/** + * @brief Get the name of wake word by index + * + * @Warning The index of wake word starts with 1 + + * @param model The model object to query + * @param word_index The index of wake word + * @returns the name of the wake word + */ +typedef char* (*esp_wn_iface_op_get_word_name_t)(model_iface_data_t *model, int word_index); + +/** + * @brief Set the detection threshold to manually adjust the probability + * + * @param model The model object to query + * @param det_threshold The threshold to trigger wake words, the range of det_threshold is 0.5~0.9999 + * @param word_index The index of wake word + * @return 0: setting failed, 1: setting success + */ +typedef int (*esp_wn_iface_op_set_det_threshold_t)(model_iface_data_t *model, float det_threshold, int word_index); + +/** + * @brief Get the wake word detection threshold of different modes + * + * @param model The model object to query + * @param word_index The index of wake word + * @returns the detection threshold + */ +typedef float (*esp_wn_iface_op_get_det_threshold_t)(model_iface_data_t *model, int word_index); + +/** + * @brief Feed samples of an audio stream to the keyword detection model and detect if there is a keyword found. + * + * @Warning The index of wake word starts with 1, 0 means no wake word is detected.
+ * + * @param model The model object to query + * @param samples An array of 16-bit signed audio samples. The array size used can be queried by the + * get_samp_chunksize function. + * @return The index of wake words, return 0 if no wake word is detected, else the index of the wake words. + */ +typedef int (*esp_wn_iface_op_detect_t)(model_iface_data_t *model, int16_t *samples); + +/** + * @brief Destroy a speech recognition model + * + * @param model Model object to destroy + */ +typedef void (*esp_wn_iface_op_destroy_t)(model_iface_data_t *model); + + +/** + * This structure contains the functions used to do operations on a wake word detection model. + */ +typedef struct { + esp_wn_iface_op_create_t create; + esp_wn_iface_op_get_samp_chunksize_t get_samp_chunksize; + esp_wn_iface_op_get_samp_rate_t get_samp_rate; + esp_wn_iface_op_get_word_num_t get_word_num; + esp_wn_iface_op_get_word_name_t get_word_name; + esp_wn_iface_op_set_det_threshold_t set_det_threshold; + esp_wn_iface_op_get_det_threshold_t get_det_threshold; + esp_wn_iface_op_detect_t detect; + esp_wn_iface_op_destroy_t destroy; +} esp_wn_iface_t;