
More edits to example auto training notebook [ckip ci]
dscripka committed Sep 5, 2023
1 parent 7cafd26 commit 8ad5248
Showing 1 changed file with 43 additions and 37 deletions.
80 changes: 43 additions & 37 deletions notebooks/automatic_model_training.ipynb
@@ -2,15 +2,15 @@
"cells": [
{
"cell_type": "markdown",
"id": "f49227ca",
"id": "43d8967d",
"metadata": {},
"source": [
"# Introduction"
]
},
{
"cell_type": "markdown",
"id": "6483fd4d",
"id": "ffbb278d",
"metadata": {},
"source": [
"This notebook demonstrates how to train custom openWakeWord models using pre-defined datasets, an automated process for synthetic data generation/augmentation, and a custom training process. While not guaranteed to always produce the best performing model, the methods shown in this notebook often produce baseline models with relatively strong performance.\n",
@@ -29,15 +29,15 @@
},
{
"cell_type": "markdown",
"id": "67ea460d",
"id": "b1d34904",
"metadata": {},
"source": [
"# Environment Setup"
]
},
{
"cell_type": "markdown",
"id": "004ef4db",
"id": "10224159",
"metadata": {},
"source": [
"To begin, we'll need to install the requirements for training custom models. In particular, a relatively recent version of Pytorch and custom fork of the [piper-sample-generator](https://github.com/dscripka/piper-sample-generator) library for generating synthetic examples for the custom model.\n",
@@ -48,7 +48,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ad77a769",
"id": "2d59a6b5",
"metadata": {},
"outputs": [],
"source": [
@@ -69,7 +69,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "da85818d",
"id": "8f30acfd",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T13:42:01.183840Z",
@@ -95,15 +95,15 @@
},
{
"cell_type": "markdown",
"id": "658ccb2a",
"id": "a4bdf3fd",
"metadata": {},
"source": [
"# Download Data"
]
},
{
"cell_type": "markdown",
"id": "1e4450ab",
"id": "e002fd0d",
"metadata": {},
"source": [
"When training new openWakeWord models using the automated procedure, four specific types of data are required:\n",
@@ -124,7 +124,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ab3c1fce",
"id": "58e4811e",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T01:07:17.746749Z",
@@ -149,29 +149,34 @@
{
"cell_type": "code",
"execution_count": null,
"id": "b28717ef",
"id": "bd867817",
"metadata": {},
"outputs": [],
"source": [
"## Download noise and background audio\n",
"\n",
"# FSD50k Noise Dataset (warning, this can take 5 minutes to prepare when streaming)\n",
"# https://zenodo.org/record/4060432\n",
"# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)\n",
"# Download one part of the audioset .tar files, extract, and convert to 16khz\n",
"# For full-scale training, it's reccomended to download the entire dataset, and\n",
"# even combine with other background noise datasets (e.g., FSD50k, Freesound, etc.)\n",
"fname = \"bal_train09.tar\"\n",
"out_dir = f\"audioset/{fname}\"\n",
"link = \"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/\" + fname\n",
"!wget -O {out_dir} {link}\n",
"!cd audioset && tar -xvf bal_train09.tar\n",
"\n",
"output_dir = \"./fsd50k\"\n",
"if not os.path.exists(\"audioset\"):\n",
" os.mkdir(\"audioset\")\n",
"\n",
"output_dir = \"./audioset_16k\"\n",
"if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n",
"fsd50k_dataset = datasets.load_dataset(\"Fhrozen/FSD50k\", split=\"validation\", streaming=True) # ~40,000 files in this split\n",
"fsd50k_dataset = iter(fsd50k_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000)))\n",
"\n",
"n_total = 500 # use only 500 clips for this example notebook, recommend increasing for full-scale training\n",
"for i in tqdm(range(n_total)):\n",
" row = next(fsd50k_dataset)\n",
" name = row['audio']['path'].split('/')[-1]\n",
"audioset_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"audioset/audio\").glob(\"**/*.flac\")]})\n",
"audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n",
"for row in tqdm(audioset_dataset):\n",
" name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n",
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, row['audio']['array'])\n",
" i += 1\n",
" if i == n_total:\n",
" break\n",
"\n",
"# Free Music Archive dataset\n",
"# https://github.com/mdeff/fma\n",
@@ -185,7 +190,7 @@
"n_hours = 1 # use only 1 hour of clips for this example notebook, recommend increasing for full-scale training\n",
"for i in tqdm(range(n_hours*3600//30)): # this works because the FMA dataset is all 30 second clips\n",
" row = next(fma_dataset)\n",
" name = row['audio']['path'].split('/')[-1]\n",
" name = row['audio']['path'].split('/')[-1].replace(\".mp3\", \".wav\")\n",
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, row['audio']['array'])\n",
" i += 1\n",
" if i == n_hours*3600//30:\n",
@@ -195,30 +200,31 @@
{
"cell_type": "code",
"execution_count": null,
"id": "b5bbe225",
"id": "203df175",
"metadata": {},
"outputs": [],
"source": [
"# Download pre-computed openWakeWord features for training and validation\n",
"\n",
"# training set (~2,000 hours)\n",
"!wget https://huggingface.co/datasets/davidscripka/openwakeword_features/blob/main/openwakeword_features_ACAV100M_2000_hrs_16bit.npy\n",
"# training set (~2,000 hours from the ACAV100M Dataset)\n",
"# See https://huggingface.co/datasets/davidscripka/openwakeword_features for more information\n",
"!wget https://huggingface.co/datasets/davidscripka/openwakeword_features/resolve/main/openwakeword_features_ACAV100M_2000_hrs_16bit.npy\n",
"\n",
"# validation set for false positive rate estimation (~11 hours)\n",
"!wget https://huggingface.co/datasets/davidscripka/openwakeword_features/blob/main/validation_set_features.npy"
"!wget https://huggingface.co/datasets/davidscripka/openwakeword_features/resolve/main/validation_set_features.npy"
]
},
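The downloaded feature files are large `.npy` arrays, so loading them with NumPy's `mmap_mode` keeps memory usage low by reading only the slices that are touched. A sketch with a small synthetic array standing in for the real multi-gigabyte download; the array shape below is illustrative, not the actual file's:

```python
import numpy as np

# Stand-in for the downloaded features: the real files are far larger,
# so memory-mapping avoids loading everything into RAM at once.
features = np.random.rand(100, 16, 96).astype(np.float32)
np.save("example_features.npy", features)

# mmap_mode="r" maps the file read-only; slicing pulls in only those rows.
loaded = np.load("example_features.npy", mmap_mode="r")
print(loaded.shape)      # (100, 16, 96)
print(loaded[:8].shape)  # (8, 16, 96)
```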
{
"cell_type": "markdown",
"id": "9265dbcf",
"id": "290865f2",
"metadata": {},
"source": [
"# Define Training Configuration"
]
},
{
"cell_type": "markdown",
"id": "1e204f3e",
"id": "041ac7e6",
"metadata": {},
"source": [
"For automated model training openWakeWord uses a specially designed training script and a [YAML](https://yaml.org/) configuration file that defines all of the information required for training a new wake word/phrase detection model.\n",
@@ -237,7 +243,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c0fea7ab",
"id": "fc70e5ab",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T18:11:33.893397Z",
@@ -254,7 +260,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "fbe862c5",
"id": "bc278709",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T15:07:00.859210Z",
@@ -282,15 +288,15 @@
},
{
"cell_type": "markdown",
"id": "c7460b02",
"id": "b46c080b",
"metadata": {},
"source": [
"# Train the Model"
]
},
{
"cell_type": "markdown",
"id": "c1de9710",
"id": "55bc110f",
"metadata": {},
"source": [
"With the data downloaded and training configuration set, we can now start training the model. We'll do this in parts to better illustrate the sequence, but you can also execute every step at once for a fully automated process."
@@ -299,7 +305,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "85115fc9",
"id": "5c0ffb3a",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T13:50:08.803326Z",
@@ -319,7 +325,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9efa15af",
"id": "190dea93",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T13:56:08.781018Z",
@@ -336,7 +342,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "916b511a",
"id": "6ec6e791",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T15:11:14.742260Z",
@@ -352,7 +358,7 @@
},
{
"cell_type": "markdown",
"id": "4925229c",
"id": "6deb1fda",
"metadata": {},
"source": [
"After the model finishes training, the auto training script will automatically convert it to ONNX and tflite versions, saving them as `<model_name>.onnx/tflite` in the present working directory, where `<model_name>` is defined in the YAML training config file.\n",
