diff --git a/docs/RECORDING.md b/docs/RECORDING.md index 27c18f65..dee89eea 100644 --- a/docs/RECORDING.md +++ b/docs/RECORDING.md @@ -5,26 +5,33 @@ In order to train a model, you need to record sounds first. You can do this by r ![Installing packages](media/settings-record.png) -This script will record sounds in seperate files of 30 milliseconds each and save them in your recordings folder ( data/recordings is the default place, which can be changed in the data/code/config.py file using the examples in lib/default_conifg.py ). +This script will record your microphone and save the detected areas inside of an SRT file. It will record in overlapping segments of 30 milliseconds. You have to be sure to record as little noise as possible. For example, if you are recording a bell sound, it is imperative that you only record that sound. -If you accidentally recorded a different sound, you can always delete the specific file from the recordings directory. + ![Installing packages](media/settings-record-progress.png) -In order to make sure you only record the sound you want to record, you can alter the power setting at the start. I usually choose a value between 1000 and 2000. -You can also trim out stuff below a specific frequency value. Neither the intensity, power or the frequency values I am using isn't actually an SI unit like dB or Hz, just some rough calculations which will go up when the loudness or frequency goes up. +During the recording, you can also pause the recording using SPACE or quit it using ESC. +If you feel a sneeze coming up, or a car passes by, you can press these keys to make sure you don't have to remove data. +If you accidentally did record a different sound, you can always press BACKSPACE or - to remove some data from the recording. -During the recording, you can also pause the recording using SPACE or quit it using ESC. If you feel a sneeze coming up, or a car passes by, you can press these keys to make sure you don't have to prune away a lot of files. +You can look at the 'Recorded' part during the recording session to see how much of your sound has been detected. ### Amount of data needed -I found that you need around 30 seconds of recorded sound, roughly 1000 samples, to get a working recognition of a specific sound. Depending on the noise it would take between a minute and two minutes to record the sounds ( there are less samples to pick from with short sounds like clicks, whereas longer sounds like vowels give more samples ). -You will start getting diminishing returns past two and a half minutes of recorded sound ( 5000 samples ), but the returns are still there. As of the moment of this writing, I used 15000 samples for the Hollow Knight demo. +The Data quantity part of the recording shows you whether we think you have enough data for a model. +The minimum required is about 16 seconds, 41 seconds is a good amount, and anything above 1 minute 22 seconds is considered excellent. +You will start getting diminishing returns after that, but the returns are still there. I used about 4 minutes per sound for the Hollow Knight demo. You can try any amount and see if they recognize well. -From this version onward, there will also be full recordings of the recording session saved in the source directory inside of the sound you are recording. This might come in handy when we start adding more sophisticated models in the future. +If you want the model to do well, you should aim to have about the same amount of recordings for every sound you record. 
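+
+For reference, the detected areas are stored as plain SRT entries in the segments folder of your recorded sound. A rough sketch of what one of these files might contain (the timestamps and the "bell" label here are only an illustration, not output from a real session):
+
+```
+1
+00:00:05,130 --> 00:00:05,640
+bell
+
+2
+00:00:06,270 --> 00:00:06,810
+bell
+```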
+ +### Checking the quality of the detection + +If you want to see if the detection was alright, you can either open up the SRT file inside the segments folder of your recorded sound and compare it to the source file, or use the comparison.wav file inside of the segments folder. +If you place both the source file and the comparison.wav file inside a program like Audacity, you can see the spots where it detected a sound. -You can use these source files to resegment the recordings you have made as well, by using the [V] menu at the start and then navigating to [S]. This will reuse the source files available to read out the wav data and persist them inside the data/output folder. +![Audacity comparing detection](media/settings-compare-detection.png) ### Background noise diff --git a/docs/media/settings-compare-detection.png b/docs/media/settings-compare-detection.png new file mode 100644 index 00000000..e68cb5cd Binary files /dev/null and b/docs/media/settings-compare-detection.png differ diff --git a/docs/media/settings-record-progress.png b/docs/media/settings-record-progress.png index 50de8a14..5f7cb9be 100644 Binary files a/docs/media/settings-record-progress.png and b/docs/media/settings-record-progress.png differ diff --git a/docs/media/settings-record.png b/docs/media/settings-record.png index b21cb90e..c9b4e432 100644 Binary files a/docs/media/settings-record.png and b/docs/media/settings-record.png differ diff --git a/lib/audio_dataset.py b/lib/audio_dataset.py index 6fc66810..8028007b 100644 --- a/lib/audio_dataset.py +++ b/lib/audio_dataset.py @@ -5,6 +5,7 @@ import numpy as np import random import math +from lib.wav import load_wav_data_from_srt class AudioDataset(Dataset): @@ -15,61 +16,60 @@ def __init__(self, grouped_data_directories, settings): self.augmented_samples = [] self.length = 0 self.training = False - rebuild_cache = False for index, label in enumerate( grouped_data_directories ): directories = grouped_data_directories[ label ] - listed_files = [] + listed_files = {} for directory in directories: - for file in os.listdir( directory ): - if( file.endswith(".wav") ): - listed_files.append( os.path.join(directory, file) ) + segments_directory = os.path.join(directory, "segments") + source_directory = os.path.join(directory, "source") + if not (os.path.exists(segments_directory) and os.path.exists(source_directory)): + continue + + source_files = os.listdir(source_directory) + srt_files = [x for x in os.listdir(segments_directory) if x.endswith(".srt")] + for source_file in source_files: + shared_key = source_file.replace(".wav", "") + + possible_srt_files = [x for x in srt_files if x.startswith(shared_key)] + if len(possible_srt_files) == 0: + continue + + # Find the highest version of the segmentation for this source file + srt_file = possible_srt_files[0] + for possible_srt_file in possible_srt_files: + current_version = int( srt_file.replace(".srt", "").replace(shared_key + ".v", "") ) + version = int( possible_srt_file.replace(".srt", "").replace(shared_key + ".v", "") ) + if version > current_version: + srt_file = possible_srt_file + + listed_files[os.path.join(source_directory, source_file)] = os.path.join(segments_directory, srt_file) listed_files_size = len( listed_files ) - print( f"Loading in {label}: {listed_files_size} files" ) - - for file_index, full_filename in enumerate( listed_files ): - print( str( math.floor(((file_index + 1 ) / listed_files_size ) * 100)) + "%", end="\r" ) - - # When the input length changes due to a different input type being used, 
we need to rebuild the cache from scratch - if (index == 0 and file_index == 0): - rebuild_cache = len(self.feature_engineering_cached(full_filename, False)) != len(self.feature_engineering_augmented(full_filename)) - - self.samples.append([full_filename, index, torch.tensor(self.feature_engineering_cached(full_filename, rebuild_cache)).float()]) - self.augmented_samples.append(None) + print( f"Loading in {label}" ) + listed_source_files = listed_files.keys() + for file_index, full_filename in enumerate( listed_source_files ): + all_samples = load_wav_data_from_srt(listed_files[full_filename], full_filename, self.settings['FEATURE_ENGINEERING_TYPE'], False) + augmented_samples = load_wav_data_from_srt(listed_files[full_filename], full_filename, self.settings['FEATURE_ENGINEERING_TYPE'], False, True) + + for sample in all_samples: + self.samples.append([full_filename, index, torch.tensor(sample).float()]) + for augmented_sample in augmented_samples: + self.augmented_samples.append([full_filename, index, torch.tensor(augmented_sample).float()]) def set_training(self, training): self.training = training - def feature_engineering_cached(self, filename, rebuild_cache=False): - # Only build a filesystem cache of feature engineering results if we are dealing with non-raw wave form - if (self.settings['FEATURE_ENGINEERING_TYPE'] != 1): - cache_dir = os.path.join(os.path.dirname(filename), "cache") - os.makedirs(cache_dir, exist_ok=True) - cached_filename = os.path.join(cache_dir, os.path.basename(filename) + "_fe") - if (os.path.isfile(cached_filename) == False or rebuild_cache == True): - data_row = training_feature_engineering(filename, self.settings) - np.savetxt( cached_filename, data_row ) - else: - cached_filename = filename - - return np.loadtxt( cached_filename, dtype='float' ) - - def feature_engineering_augmented(self, filename): - return augmented_feature_engineering(filename, self.settings) - def __len__(self): return len( self.samples ) def __getitem__(self, idx): # During training, get a 10% probability that you get an augmented sample if (self.training and random.uniform(0, 1) >= 0.9 ): - if (self.augmented_samples[idx] is None): - self.augmented_samples[idx] = [self.samples[idx][0], self.samples[idx][1], torch.tensor(self.feature_engineering_augmented(self.samples[idx][0])).float()] - return self.augmented_samples[idx][2], self.augmented_samples[idx][1] - else: - return self.samples[idx][2], self.samples[idx][1] - + if (idx in self.augmented_samples): + return self.augmented_samples[idx][2], self.augmented_samples[idx][1] + return self.samples[idx][2], self.samples[idx][1] + def get_labels(self): return self.paths diff --git a/lib/default_config.py b/lib/default_config.py index bab17a2e..ec22399c 100644 --- a/lib/default_config.py +++ b/lib/default_config.py @@ -67,4 +67,8 @@ if( SPEECHREC_ENABLED == True ): SPEECHREC_ENABLED = dragonfly_spec is not None - \ No newline at end of file +BACKGROUND_LABEL = "silence" + +# Detection strategies +CURRENT_VERSION = 1 +CURRENT_DETECTION_STRATEGY = "auto_dBFS_mend_dBFS_30ms_secondary_dBFS_reject_cont_45ms_repair" diff --git a/lib/key_poller.py b/lib/key_poller.py index b2ea49cb..deff08b9 100644 --- a/lib/key_poller.py +++ b/lib/key_poller.py @@ -30,7 +30,10 @@ def __exit__(self, type, value, traceback): def poll(self): if( IS_WINDOWS == True ): if( msvcrt.kbhit() ): - return msvcrt.getch().decode() + ch = msvcrt.getch() + if ch == b'\xe0' or ch == b'\000': + ch = msvcrt.getch() + return ch.decode() else: dr,dw,de = select.select([sys.stdin], 
[], [], 0) if not dr == []: diff --git a/lib/learn_data.py b/lib/learn_data.py index 193cceb3..db530043 100644 --- a/lib/learn_data.py +++ b/lib/learn_data.py @@ -22,6 +22,7 @@ from sklearn.neural_network import * from lib.combine_models import define_settings, get_current_default_settings from lib.audio_model import AudioModel +from lib.wav import load_wav_files_with_srts def learn_data(): dir_path = os.path.join( os.path.dirname( os.path.dirname( os.path.realpath(__file__)) ), DATASET_FOLDER) @@ -205,7 +206,7 @@ def load_data( dir_path, max_files, input_type ): for str_label, directories in grouped_data_directories.items(): # Add a label used for classifying the sounds id_label = get_label_for_directory( "".join( directories ) ) - cat_dataset_x, cat_dataset_labels, featureEngineeringTime = load_wav_files( directories, str_label, id_label, 0, max_files, input_type ) + cat_dataset_x, cat_dataset_labels, featureEngineeringTime = load_wav_files_with_srts( directories, str_label, id_label, 0, max_files, input_type ) totalFeatureEngineeringTime += featureEngineeringTime dataset_x.extend( cat_dataset_x ) dataset_labels.extend( cat_dataset_labels ) diff --git a/lib/machinelearning.py b/lib/machinelearning.py index 5d9ba806..260a7e46 100644 --- a/lib/machinelearning.py +++ b/lib/machinelearning.py @@ -114,7 +114,6 @@ def augmented_feature_engineering( wavFile, settings ): print( "OLD MFCC TYPE IS NOT SUPPORTED FOR TRAINING PYTORCH" ) return data_row - def get_label_for_directory( setdir ): return float( int(hashlib.sha256( setdir.encode('utf-8')).hexdigest(), 16) % 10**8 ) diff --git a/lib/migrate_data.py b/lib/migrate_data.py new file mode 100644 index 00000000..05eb626e --- /dev/null +++ b/lib/migrate_data.py @@ -0,0 +1,78 @@ +from config.config import * +import os +from lib.stream_processing import process_wav_file +from lib.print_status import create_progress_bar, clear_previous_lines, get_current_status, reset_previous_lines +from .typing import DetectionState +import time + +def check_migration(): + version_detected = CURRENT_VERSION + recording_dirs = os.listdir(RECORDINGS_FOLDER) + for file in recording_dirs: + if os.path.isdir(os.path.join(RECORDINGS_FOLDER, file)): + segments_folder = os.path.join(RECORDINGS_FOLDER, file, "segments") + if not os.path.exists(segments_folder): + version_detected = 0 + break + else: + source_files = os.listdir(os.path.join(RECORDINGS_FOLDER, file, "source")) + for source_file in source_files: + srt_file = source_file.replace(".wav", ".v" + str(CURRENT_VERSION) + ".srt") + if not os.path.exists(os.path.join(segments_folder, srt_file)): + version_detected = 0 + break + + if version_detected < CURRENT_VERSION: + print("----------------------------") + print("!! Improvement to segmentation found !!") + print("This can help improve the data gathering from your recordings which make newer models better") + print("Resegmenting your data may take a while") + migrate_data() + +def migrate_data(): + print("----------------------------") + recording_dirs = os.listdir(RECORDINGS_FOLDER) + for label in recording_dirs: + source_dir = os.path.join(RECORDINGS_FOLDER, label, "source") + if os.path.isdir(source_dir): + segments_dir = os.path.join(RECORDINGS_FOLDER, label, "segments") + if not os.path.exists(segments_dir): + os.makedirs(segments_dir) + wav_files = [x for x in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, x)) and x.endswith(".wav")] + if len(wav_files) == 0: + continue + print( "Resegmenting " + label + "..." 
) + progress = 0 + progress_chunk = 1 / len( wav_files ) + skipped_amount = 0 + for index, wav_file in enumerate(wav_files): + wav_file_location = os.path.join(source_dir, wav_file) + srt_file_location = os.path.join(segments_dir, wav_file.replace(".wav", ".v" + str(CURRENT_VERSION) + ".srt")) + output_file_location = os.path.join(segments_dir, wav_file.replace(".wav", "_detection.wav")) + + # Only resegment if the new version does not exist already + if not os.path.exists(srt_file_location): + process_wav_file(wav_file_location, srt_file_location, output_file_location, [label], \ + lambda internal_progress, state: print_migration_progress(progress + (internal_progress * progress_chunk), state) ) + else: + skipped_amount += 1 + progress = index / len( wav_files ) + progress_chunk + + if progress == 1 and skipped_amount < len(wav_files): + clear_previous_lines(1) + + clear_previous_lines(1) + print( label + " resegmented!" if skipped_amount < len(wav_files) else label + " already properly segmented!" ) + + time.sleep(1) + print("Finished migrating data!") + print("----------------------------") + +def print_migration_progress(progress, state: DetectionState): + status_lines = get_current_status(state) + line_count = 1 + len(status_lines) if progress > 0 or state.state == "processing" else 0 + reset_previous_lines(line_count) if progress < 1 else clear_previous_lines(line_count) + print( create_progress_bar(progress) ) + if progress != 1: + for line in status_lines: + print( line ) \ No newline at end of file diff --git a/lib/print_status.py b/lib/print_status.py new file mode 100644 index 00000000..1b548d48 --- /dev/null +++ b/lib/print_status.py @@ -0,0 +1,118 @@ +from .typing import DetectionState +from typing import List +from .srt import ms_to_srt_timestring +import os +import sys + +# Needed to make escape characters work on Windows for some reason +if os.name == 'nt': + os.system("") +ANSI_CODE_LINE_UP = '\033[1A' +ANSI_CODE_LINE_CLEAR = '\x1b[2K' + +# If no UTF-8 characters are supported, use ascii characters instead +PROGRESS_FILLED = '#' if sys.stdout.encoding != 'utf-8' else '\u2588' +PROGRESS_AVAILABLE = '-' if sys.stdout.encoding != 'utf-8' else '\u2591' +LINE_LENGTH = 50 + +def create_progress_bar(percentage: float = 1.0) -> str: + filled_characters = round(max(0, min(LINE_LENGTH, LINE_LENGTH * percentage))) + return "".rjust(filled_characters, PROGRESS_FILLED).ljust(LINE_LENGTH, PROGRESS_AVAILABLE) + +def get_current_status(detection_state: DetectionState, extra_states: List[DetectionState] = []) -> List[str]: + total_ms_recorded = detection_state.ms_recorded + for extra_state in extra_states: + total_ms_recorded += extra_state.ms_recorded + recorded_timestring = ms_to_srt_timestring( total_ms_recorded, False) + + # Quality rating was manually established by doing some testing with added noise + # And finding the results becoming worse when the SNR went lower than 10 + quality = "" + if total_ms_recorded > 10000: + if detection_state.expected_snr >= 25: + quality = "Excellent" + elif detection_state.expected_snr >= 20: + quality = "Great" + elif detection_state.expected_snr >= 15: + quality = "Good" + elif detection_state.expected_snr >= 10: + quality = "Average" + elif detection_state.expected_snr >= 7: + quality = "Poor" + else: + quality = "Unusable" + + lines = [ + ".".ljust(LINE_LENGTH - 2, "-") + ".", + "| " + "Listening for:" + recorded_timestring.rjust(LINE_LENGTH - 19) + " |", + ] + + if detection_state.state == "recording": + lines.append("| " + "Sound Quality: " + 
quality.rjust(LINE_LENGTH - 20) + " |") + elif detection_state.state == "processing": + lines.append("| " + "PROCESSING...".ljust(LINE_LENGTH - 5) + " |") + elif detection_state.state == "paused": + lines.append("| " + "PAUSED - Resume using SPACE".ljust(LINE_LENGTH - 5) + " |") + else: + lines.append("| " + detection_state.state.upper().ljust(LINE_LENGTH - 5) + " |") + + lines.append("| " + ("dBFS:" + str(round(detection_state.latest_dBFS)).rjust(LINE_LENGTH - 10)) + " |") + if detection_state.advanced_logging: + lines.extend([ + "|".ljust(LINE_LENGTH - 2,"-") + "|", + "| " + "Est. values for thresholding".ljust(LINE_LENGTH - 5) + " |", + "|".ljust(LINE_LENGTH - 2,"-") + "|", + "| " + ("Noise floor (dBFS):" + str(round(detection_state.expected_noise_floor)).rjust(LINE_LENGTH - 24)) + " |", + "| " + ("SNR:" + str(round(detection_state.expected_snr)).rjust(LINE_LENGTH - 9)) + " |", + ]) + + for label in detection_state.labels: + # Quantity rating is based on 5000 30ms windows being good enough to train a label from the example model + # And 1000 30ms windows being enough to train a label decently + # With atleast 10 percent extra for a possible hold-out set during training + total_ms_detected = label.ms_detected + label.previous_detected + for extra_state in extra_states: + for extra_label in extra_state.labels: + if extra_label.label == label.label: + total_ms_detected += extra_label.ms_detected + extra_label.previous_detected + + percent_to_next = 0 + quantity = "" + if total_ms_detected < 16500: + percent_to_next = (total_ms_detected / 16500 ) * 100 + quantity = "Not enough" + elif total_ms_detected > 16500 and total_ms_detected < 41250: + percent_to_next = ((total_ms_detected - 16500) / (41250 - 16500) ) * 100 + quantity = "Sufficient" + elif total_ms_detected >= 41250 and total_ms_detected < 82500: + percent_to_next = ((total_ms_detected - 41250) / (82500 - 41250) ) * 100 + quantity = "Good" + elif total_ms_detected >= 82500: + quantity = "Excellent" + + if percent_to_next != 0: + quantity += " (" + str(round(percent_to_next)) + "%)" + + lines.extend([ + "|".ljust(LINE_LENGTH - 2,"-") + "|", + "| " + label.label.ljust(LINE_LENGTH - 5) + " |", + "| " + "Recorded: " + ms_to_srt_timestring( total_ms_detected, False ).rjust(LINE_LENGTH - 15) + " |", + "| " + "Data Quantity: " + quantity.rjust(LINE_LENGTH - 20) + " |", + ]) + + if detection_state.advanced_logging: + lines.append( "| " + ("type:" + str(label.duration_type if label.duration_type else "DETERMINING...").upper().rjust(LINE_LENGTH - 10)) + " |" ) + lines.append( "| " + ("dBFS treshold:" + str(round(label.min_dBFS, 2)).rjust(LINE_LENGTH - 19)) + " |" ) + lines.append("'".ljust(LINE_LENGTH - 2,"-") + "'") + + return lines + +def reset_previous_lines(line_count): + line = ""; + for i in range(0,line_count): + line += ANSI_CODE_LINE_UP + print(line, end=ANSI_CODE_LINE_CLEAR ) + +def clear_previous_lines(line_count): + for i in range(0,line_count): + print(ANSI_CODE_LINE_UP, end=ANSI_CODE_LINE_CLEAR ) \ No newline at end of file diff --git a/lib/record_data.py b/lib/record_data.py index cc684420..6e4976dc 100644 --- a/lib/record_data.py +++ b/lib/record_data.py @@ -1,16 +1,7 @@ from config.config import * import pyaudio -import wave import time -from time import sleep -import scipy.io.wavfile -import audioop import math -import numpy as np -from scipy.fftpack import fft -from scipy.fftpack import fftfreq -from scipy.signal import blackmanharris -from lib.machinelearning import get_loudest_freq, get_recording_power import os import 
glob from queue import * @@ -19,14 +10,20 @@ import sys from lib.listen import validate_microphone_input from lib.key_poller import KeyPoller -import struct +from lib.print_status import get_current_status, reset_previous_lines, clear_previous_lines +from lib.typing import DetectionLabel, DetectionState +from lib.stream_processing import CURRENT_VERSION, CURRENT_DETECTION_STRATEGY +from lib.typing import DetectionState, DetectionFrame +from lib.stream_recorder import StreamRecorder +from lib.srt import count_total_label_ms, ms_to_srt_timestring +from typing import List # Countdown from seconds to 0 def countdown( seconds ): with KeyPoller() as key_poller: for i in range( -seconds, 0 ): print("recording in... " + str(abs(i)), end="\r") - sleep( 1 ) + time.sleep( 1 ) if( record_controls(key_poller) == False ): return False; print(" ", end="\r") @@ -34,44 +31,131 @@ def countdown( seconds ): def record_controls( key_poller, recordQueue=None ): global currently_recording - global streams + global recorders ESCAPEKEY = '\x1b' SPACEBAR = ' ' + BACKSPACE = '\x08' + MINUS = '-' character = key_poller.poll() - if(character is not None): - if( character == SPACEBAR ): - print( "Recording paused!" ) + if(character is not None): + # Clear the last 3 seconds if backspace was pressed + if character == BACKSPACE or character == MINUS: + if (recorders is not None): + main_state = None + secondary_states = [] + for mic_index in recorders: + if main_state is None: + main_state = recorders[mic_index].get_detection_state() + else: + secondary_states.append(recorders[mic_index].get_detection_state()) + recorders[mic_index].pause() + should_resume = False + + # Clear and update the detection states + index = 0 + if main_state is not None: + main_state.state = "deleting" + print_status(main_state, secondary_states) + + for mic_index in recorders: + should_resume = recorders[mic_index].clear(3) + if index == 0: + main_state = recorders[mic_index].get_detection_state() + else: + secondary_states[index - 1] = secondary_states[index - 1].get_detection_state() + index += 1 + print_status(main_state, secondary_states) + + if main_state is not None: + main_state.state = "recording" + print_status(main_state, secondary_states) + + # Wait for the sound of the space bar to dissipate before continuing recording + time.sleep(0.3) + if should_resume: + for mic_index in recorders: + recorders[mic_index].resume() + elif( character == ESCAPEKEY ): + currently_recording = False + return False - if (streams is not None): - for stream in streams: - streams[stream].stop_stream() + elif character == SPACEBAR: + if( recordQueue == None ): + print( "Recording paused!" 
) + + main_state = None + secondary_states = [] + if (recorders is not None): + for mic_index in recorders: + if main_state is None: + main_state = recorders[mic_index].get_detection_state() + else: + secondary_states.append(recorders[mic_index].get_detection_state()) + recorders[mic_index].pause() + recorders[mic_index].reset_label_count() + + # Do post processing and printing of the status + if main_state is not None: + index = 0 + main_state.state = "deleting" + print_status(main_state, secondary_states) + + for mic_index in recorders: + recorders[mic_index].post_processing( + lambda internal_progress, state, extra=secondary_states: print_status(main_state, extra) + ) + + # Update the states so the numbers count up nicely + if index == 0: + main_state = recorders[mic_index].get_detection_state() + else: + secondary_states[index - 1] = recorders[mic_index].get_detection_state() + index += 1 + + main_state.state = "paused" + print_status(main_state, secondary_states) # Pause the recording by looping until we get a new keypress while( True ): - - ## If the audio queue exists - make sure to clear it continuously + # If the audio queue exists - make sure to clear it continuously if( recordQueue != None ): for key in recordQueue: recordQueue[key].queue.clear() - + character = key_poller.poll() - if(character is not None): - if( character == SPACEBAR ): - print( "Recording resumed!" ) - if (streams is not None): - for stream in streams: - streams[stream].start_stream() + if character is not None: + if character == SPACEBAR: + if main_state is not None: + main_state.state = "recording" + print_status(main_state, secondary_states) + + # Wait for the sound of the space bar to dissipate before continuing recording + time.sleep(0.3) + if recorders is not None: + for mic_index in recorders: + recorders[mic_index].resume() return True - elif( character == ESCAPEKEY ): - print( "Recording stopped" ) + # Clear the last 3 seconds if backspace was pressed + elif character == BACKSPACE or character == MINUS: + if recorders is not None and main_state is not None: + index = 0 + for mic_index in recorders: + recorders[mic_index].clear(3) + if index == 0: + main_state = recorders[mic_index].get_detection_state() + else: + secondary_states[index - 1] = secondary_states[index - 1].get_detection_state() + index += 1 + print_status(main_state, secondary_states) + main_state.state = "paused" + print_status(main_state, secondary_states) + + # Stop the recording session + elif character == ESCAPEKEY: currently_recording = False return False time.sleep(0.3) - elif( character == ESCAPEKEY ): - print( "Recording stopped" ) - currently_recording = False - return False return True def record_sound(): @@ -83,34 +167,8 @@ def record_sound(): print( "And record tiny audio files to be used for learning later" ) print( "-------------------------" ) - # Note - this assumes a maximum of 10 possible input devices, which is probably wrong but eh - print("What microphone do you want to record with? 
( Empty is the default system mic, [X] exits the recording menu )") - print("You can put a space in between numbers to record with multiple microphones") - for index in range(audio.get_device_count()): - device_info = audio.get_device_info_by_index(index) - if (device_info and device_info['name'] and device_info['maxInputChannels'] > 0): - default_mic = " - " if index != INPUT_DEVICE_INDEX else " DEFAULT - " - host_api = audio.get_host_api_info_by_index(device_info['hostApi']) - host_api_string = " " + host_api["name"] if host_api else "" - print("[" + str(index) + "]" + default_mic + device_info['name'] + host_api_string) - - mic_index_string = input("") - mic_indecis = [] - if mic_index_string == "": - mic_indecis = [str(INPUT_DEVICE_INDEX)] - elif mic_index_string.strip().lower() == "x": - return; - else: - mic_indecis = mic_index_string.split() - valid_mics = [] - for mic_index in mic_indecis: - if (str.isdigit(mic_index) and validate_microphone_index(audio, int(mic_index))): - valid_mics.append(int(mic_index)) - - if len(valid_mics) == 0: - print("No usable microphones selected - Exiting") - return; - + ms_per_frame = math.floor(RECORD_SECONDS / SLIDING_WINDOW_AMOUNT * 1000) + directory_counts = {} try: if os.path.exists(RECORDINGS_FOLDER): glob_path = RECORDINGS_FOLDER + "/*/" @@ -122,8 +180,18 @@ def record_sound(): # cut off glob path, but leave two more characters # at the start to account for */ # also remove the trailing slash - print(" - ", dirname[len(glob_path) - 2:-1]) + directory_name = dirname[len(glob_path) - 2:-1] + + # Count the currently recorded amount of data + current_count = count_total_label_ms(directory_name, os.path.join(RECORDINGS_FOLDER, directory_name), ms_per_frame) + directory_counts[directory_name] = current_count + time_recorded = " ( " + ms_to_srt_timestring(current_count, False).split(",")[0] + " )" + + print(" - ", directory_name.ljust(30) + time_recorded ) print("") + print("NOTE: It is recommended to record roughly the same amount for each sound") + print("As it will improve the ability for the machine learning models to learn from the data") + print("") except: # Since this is just a convenience feature, exceptions shall not # cause recording to abort, whatever happens @@ -138,35 +206,48 @@ def record_sound(): if not os.path.exists(RECORDINGS_FOLDER + "/" + directory + "/source"): os.makedirs(RECORDINGS_FOLDER + "/" + directory + "/source") - print("What signal power ( loudness ) threshold do you need?") - print("(if you do not know, start with something like 10000 and see afterwards") - print("what power values you get while recording.)") - power_threshold = input("power: ") - if( power_threshold == "" ): - power_threshold = 0 - else: - power_threshold = int( power_threshold ) - - print("What frequency threshold do you need?") - print("(you may not need this at all, so feel free to just press enter here)") - frequency_threshold = input("frequency: ") - if( frequency_threshold == "" ): - frequency_threshold = 0 + # Note - this assumes a maximum of 10 possible input devices, which is probably wrong but eh + print("What microphone do you want to record with? 
( Empty is the default system mic, [X] exits the recording menu )") + print("You can put a space in between numbers to record with multiple microphones") + for index in range(audio.get_device_count()): + device_info = audio.get_device_info_by_index(index) + if (device_info and device_info['name'] and device_info['maxInputChannels'] > 0): + default_mic = " - " if index != INPUT_DEVICE_INDEX else " DEFAULT - " + host_api = audio.get_host_api_info_by_index(device_info['hostApi']) + host_api_string = " " + host_api["name"] if host_api else "" + print("[" + str(index) + "]" + default_mic + device_info['name'] + host_api_string) + + mic_index_string = input("") + mic_indecis = [] + if mic_index_string == "": + mic_indecis = [str(INPUT_DEVICE_INDEX)] + elif mic_index_string.strip().lower() == "x": + return; else: - frequency_threshold = int( frequency_threshold ) - begin_threshold = 10000 + mic_indecis = mic_index_string.split() + valid_mics = [] + for mic_index in mic_indecis: + if (str.isdigit(mic_index) and validate_microphone_index(audio, int(mic_index))): + valid_mics.append(int(mic_index)) + if len(valid_mics) == 0: + print("No usable microphones selected - Exiting") + return; + print("") - print("You can pause/resume the recording session using the [SPACE] key, and stop the recording using the [ESC] key" ) + print("Record keyboard controls:") + print("[SPACE] is used to pause and resume the recording session") + print("[BACKSPACE] or [-] removes the last 3 seconds of the recording") + print("[ESC] stops the current recording") + print("") - global streams global recordQueue - global audios - global files_recorded - files_recorded = 0 - streams = {} - audios = {} + global recorders + recorders = {} recordQueue = {} + labels = {} + labels[directory] = directory_counts[directory] if directory in directory_counts else 0 + if( countdown( 5 ) == False ): return; @@ -175,158 +256,129 @@ def record_sound(): time_string = str(int(time.time())) for index, microphone_index in enumerate(valid_mics): - FULL_WAVE_OUTPUT_FILENAME = RECORDINGS_FOLDER + "/" + directory + "/source/i_0__p_" + str(power_threshold) + \ - "__f_" + str(frequency_threshold) + "__begin_" + str(begin_threshold) + "__mici_" + str(microphone_index) + "__" + time_string + ".wav" - WAVE_OUTPUT_FILENAME = RECORDINGS_FOLDER + "/" + directory + "/" + time_string + "__mici_" + str(microphone_index) + "__file"; - WAVE_OUTPUT_FILE_EXTENSION = ".wav"; - - non_blocking_record(power_threshold, frequency_threshold, begin_threshold, WAVE_OUTPUT_FILENAME, WAVE_OUTPUT_FILE_EXTENSION, FULL_WAVE_OUTPUT_FILENAME, microphone_index, index==0) + FULL_WAVE_OUTPUT_FILENAME = RECORDINGS_FOLDER + "/" + directory + "/source/mici_" + str(microphone_index) + "__" + time_string + ".wav" + SRT_FILENAME = RECORDINGS_FOLDER + "/" + directory + "/segments/mici_" + str(microphone_index) + "__" + time_string + ".v" + str(CURRENT_VERSION) + ".srt" + non_blocking_record(labels, FULL_WAVE_OUTPUT_FILENAME, SRT_FILENAME, microphone_index, index==0) - # wait for stream to finish (5) - while currently_recording: + # wait for stream to finish + while currently_recording == True: time.sleep(0.1) - for microphone_index in valid_mics: - streams['index' + str(microphone_index)].stop_stream() - streams['index' + str(microphone_index)].close() - audios['index' + str(microphone_index)].terminate() + main_state = None + secondary_states = [] + for mic_index in recorders: + if main_state is None: + main_state = recorders[mic_index].get_detection_state() + else: + 
secondary_states.append(recorders[mic_index].get_detection_state()) + recorders[mic_index].pause() + + index = 0 + for mic_index in recorders: + callback = None if currently_recording == -1 else lambda internal_progress, state, extra=secondary_states: print_status(main_state, extra) + recorders[mic_index].stop( + callback + ) + + # Update the states so the numbers count up nicely + if index == 0: + main_state = recorders[mic_index].get_detection_state() + else: + secondary_states[index - 1] = recorders[mic_index].get_detection_state() + index += 1 + if currently_recording != -1: + main_state.state = "processed" + print_status(main_state, secondary_states) + # Consumes the recordings in a sliding window fashion - Always combining the two latest chunks together -def record_consumer(power_threshold, frequency_threshold, begin_threshold, WAVE_OUTPUT_FILENAME, WAVE_OUTPUT_FILE_EXTENSION, FULL_WAVE_OUTPUT_FILENAME, MICROPHONE_INPUT_INDEX, audio, streams, print_stuff=False): +def record_consumer(labels, FULL_WAVE_OUTPUT_FILENAME, SRT_FILE, MICROPHONE_INPUT_INDEX, print_stuff=False): global recordQueue global currently_recording - global files_recorded - indexedQueue = recordQueue['index' + str(MICROPHONE_INPUT_INDEX)] + global recorders + mic_index = 'index' + str(MICROPHONE_INPUT_INDEX) + indexedQueue = recordQueue[mic_index] + recorder = recorders[mic_index] - j = 0 - record_wave_file_count = 0 - audioFrames = [] - - # Set the proper thresholds for starting recordings - delay_threshold = 0 - if( begin_threshold < 0 ): - delay_threshold = begin_threshold * -1 - begin_threshold = 1000 + if print_stuff: + current_status = recorder.get_status() + for line in current_status: + print( line ) - totalAudioFrames = [] try: with KeyPoller() as key_poller: - # Write the source file first with the right settings to add the headers, and write the data later - totalWaveFile = wave.open(FULL_WAVE_OUTPUT_FILENAME, 'wb') - totalWaveFile.setnchannels(CHANNELS) - totalWaveFile.setsampwidth(audio.get_sample_size(FORMAT)) - totalWaveFile.setframerate(RATE) - totalWaveFile.close() - - # This is used to modify the wave file directly later - # Thanks to hydrogen18.com for offering the wav file explanation and code - CHUNK_SIZE_OFFSET = 4 - DATA_SUB_CHUNK_SIZE_SIZE_OFFSET = 40 - - LITTLE_ENDIAN_INT = struct.Struct('= SLIDING_WINDOW_AMOUNT ): - j+=1 - audioFrames = audioFrames[-SLIDING_WINDOW_AMOUNT:] - - byteString = b''.join(audioFrames) - fftData = np.frombuffer( byteString, dtype=np.int16 ) - frequency = get_loudest_freq( fftData, RECORD_SECONDS ) - power = get_recording_power( fftData, RECORD_SECONDS ) - - fileid = "%0.2f" % ((j) * RECORD_SECONDS ) - - if( record_controls( key_poller, recordQueue ) == False ): - for stream in streams: - streams[stream].stop_stream() - currently_recording = False - break; - - if( frequency > frequency_threshold and power > power_threshold ): - record_wave_file_count += 1 - if( record_wave_file_count <= begin_threshold and record_wave_file_count > delay_threshold ): - files_recorded += 1 - if print_stuff: - print( "Files recorded: %0d - Power: %0d - Freq: %0d - Saving %s" % ( files_recorded, power, frequency, fileid ) ) - waveFile = wave.open(WAVE_OUTPUT_FILENAME + fileid + WAVE_OUTPUT_FILE_EXTENSION, 'wb') - waveFile.setnchannels(CHANNELS) - waveFile.setsampwidth(audio.get_sample_size(FORMAT)) - waveFile.setframerate(RATE) - waveFile.writeframes(byteString) - waveFile.close() - else: - if print_stuff: - print( "Files recorded: %0d - Power: %0d - Freq: %0d" % ( files_recorded, power, 
frequency ) ) - else: - record_wave_file_count = 0 - if print_stuff: - print( "Files recorded: %0d - Power: %0d - Freq: %0d" % ( files_recorded, power, frequency ) ) - - # Append to the total wav file only once every ten audio frames ( roughly once every 225 milliseconds ) - if (len(totalAudioFrames) >= 15 ): - byteString = b''.join(totalAudioFrames) - totalFrameCount += len(byteString) - totalAudioFrames = [] - appendTotalFile = open(FULL_WAVE_OUTPUT_FILENAME, 'ab') - appendTotalFile.write(byteString) - appendTotalFile.close() - - # Set the amount of frames available and chunk size - # By overriding the header part of the wave file manually - # Which wouldn't be needed if the wave package supported appending properly - # Thanks to hydrogen18.com for the explanation and code - appendTotalFile = open(FULL_WAVE_OUTPUT_FILENAME, 'r+b') - appendTotalFile.seek(0,2) - chunk_size = appendTotalFile.tell() - 8 - appendTotalFile.seek(CHUNK_SIZE_OFFSET) - appendTotalFile.write(LITTLE_ENDIAN_INT.pack(chunk_size)) - appendTotalFile.seek(DATA_SUB_CHUNK_SIZE_SIZE_OFFSET) - sample_length = 2 * totalFrameCount - appendTotalFile.write(LITTLE_ENDIAN_INT.pack(sample_length)) - appendTotalFile.close() - sleep(0.001) + recorder.add_audio_frame(indexedQueue.get()) + if print_stuff: + extra_states = [] + for recorder_mic_index in recorders: + if mic_index != recorder_mic_index and recorder: + extra_states.append(recorders[recorder_mic_index].get_detection_state()) + current_status = recorder.get_status(extra_states) + reset_previous_lines(len(current_status)) + for line in current_status: + print( line ) + + # Only listen for keys in the main listener + if print_stuff: + record_controls( key_poller, recordQueue ) + + time.sleep(0.001) + except Exception as e: print( "----------- ERROR DURING RECORDING -------------- " ) exc_type, exc_value, exc_tb = sys.exc_info() traceback.print_exception(exc_type, exc_value, exc_tb) - for stream in streams: - streams[stream].stop_stream() - currently_recording = False + currently_recording = -1 def multithreaded_record( in_data, frame_count, time_info, status, queue ): queue.put( in_data ) return in_data, pyaudio.paContinue -# Records a non blocking audio stream and saves the chunks onto a queue -# The queue will be used as a sliding window over the audio, where two chunks are combined into one audio file -def non_blocking_record(power_threshold, frequency_threshold, begin_threshold, WAVE_OUTPUT_FILENAME, WAVE_OUTPUT_FILE_EXTENSION, FULL_WAVE_OUTPUT_FILENAME, MICROPHONE_INPUT_INDEX, print_logs): +# Records a non blocking audio stream and saves the source and SRT file for it +def non_blocking_record(labels, FULL_WAVE_OUTPUT_FILENAME, SRT_FILE, MICROPHONE_INPUT_INDEX, print_logs): global recordQueue - global streams - global audios + global recorders mic_index = 'index' + str(MICROPHONE_INPUT_INDEX) - + recordQueue[mic_index] = Queue(maxsize=0) - micindexed_lambda = lambda in_data, frame_count, time_info, status, queue=recordQueue['index' + str(MICROPHONE_INPUT_INDEX)]: multithreaded_record(in_data, frame_count, time_info, status, queue) - audios[mic_index] = pyaudio.PyAudio() - streams[mic_index] = audios[mic_index].open(format=FORMAT, channels=CHANNELS, - rate=RATE, input=True, - input_device_index=MICROPHONE_INPUT_INDEX, - frames_per_buffer=round( RATE * RECORD_SECONDS / SLIDING_WINDOW_AMOUNT ), - stream_callback=micindexed_lambda) - - consumer = threading.Thread(name='consumer', target=record_consumer, args=(power_threshold, frequency_threshold, begin_threshold, 
WAVE_OUTPUT_FILENAME, WAVE_OUTPUT_FILE_EXTENSION, FULL_WAVE_OUTPUT_FILENAME, MICROPHONE_INPUT_INDEX, audios[mic_index], streams, print_logs)) + micindexed_lambda = lambda in_data, frame_count, time_info, status, queue=recordQueue[mic_index]: multithreaded_record(in_data, frame_count, time_info, status, queue) + + detection_strategy = CURRENT_DETECTION_STRATEGY + ms_per_frame = math.floor(RECORD_SECONDS / SLIDING_WINDOW_AMOUNT * 1000) + detection_labels = [] + for label in list(labels.keys()): + detection_labels.append(DetectionLabel(label, 0, labels[label], "", 0, 0, 0, 0)) + + audio = pyaudio.PyAudio() + + recorders[mic_index] = StreamRecorder( + audio, + audio.open(format=FORMAT, channels=CHANNELS, + rate=RATE, input=True, + input_device_index=MICROPHONE_INPUT_INDEX, + frames_per_buffer=round( RATE * RECORD_SECONDS / SLIDING_WINDOW_AMOUNT ), + stream_callback=micindexed_lambda), + FULL_WAVE_OUTPUT_FILENAME, + SRT_FILE, + DetectionState(detection_strategy, "recording", ms_per_frame, 0, True, 0, 0, 0, detection_labels) + ) + + consumer = threading.Thread(name='consumer', target=record_consumer, args=(labels, FULL_WAVE_OUTPUT_FILENAME, SRT_FILE, MICROPHONE_INPUT_INDEX, print_logs)) consumer.setDaemon( True ) consumer.start() - streams[mic_index].start_stream() - + recorders[mic_index].resume() + +def print_status(detection_state: DetectionState, extra_states: List[DetectionState]): + current_status = get_current_status(detection_state, extra_states) + reset_previous_lines(len(current_status)) + for line in current_status: + print( line ) + def validate_microphone_index(audio, input_index): micDict = {'name': 'Missing Microphone index ' + str(input_index)} try: diff --git a/lib/signal_processing.py b/lib/signal_processing.py new file mode 100644 index 00000000..bdb86213 --- /dev/null +++ b/lib/signal_processing.py @@ -0,0 +1,108 @@ +import math +import numpy as np +from scipy.fftpack import fft, rfft, fft2, dct +import audioop +from python_speech_features import mfcc +from .mfsc import Mfsc +from typing import List, Tuple +import os +from config.config import RATE + +long_byte_size = 4 +_mfscs = {} + +# Determine the decibel based on full scale of 16 bit ints ( same as Audacity ) +def determine_dBFS(waveData: np.array) -> float: + return 20 * math.log10(determine_power(waveData) / math.pow(32767, 2)) + +def determine_power(waveData: np.array) -> float: + return audioop.rms(waveData, long_byte_size) + +# This power measurement is the old representation for human readability +def determine_legacy_power(waveData: np.array) -> float: + return determine_power(audioop.rms(waveData, 4)) / 1000 + +# Old fundamental frequency finder - this one doesn't show frequency in Hz +def determine_legacy_frequency(waveData: np.array) -> float: + fft_result = fft( waveData ) + positiveFreqs = np.abs( fft_result[ 0:round( len(fft_result) / 2 ) ] ) + highestFreq = 0 + loudestPeak = 500 + frequencies = [0] + for freq in range( 0, len( positiveFreqs ) ): + if( positiveFreqs[ freq ] > loudestPeak ): + loudestPeak = positiveFreqs[ freq ] + highestFreq = freq + + if( loudestPeak > 500 ): + frequencies.append( highestFreq ) + + if( recordLength < 1 ): + # Considering our sound sample is, for example, 100 ms, our lowest frequency we can find is 10Hz ( I think ) + # So add that as a base to our found frequency to get Hz - This is probably wrong + freqInHz = ( 1 / recordLength ) + np.amax( frequencies ) + else: + # I have no clue how to even pretend to know how to calculate Hz for fft frames longer than a second + freqInHz 
= np.amax( frequencies ) + + return freqInHz + +# Approximate vocal formants F1 and F2 using weighted average +# Goal is to have a light weight, smooth pair of values that can be properly controlled by the user +# Heuristics taken based on https://home.cc.umanitoba.ca/~krussll/phonetics/acoustic/formants.html +# 241 taken from assumption 15ms * 16khz + 1 +def determine_formant_frequencies(waveData: np.array, bin_size: float = 241) -> Tuple[float, float]: + bin_range = 8000 / bin_size + + # Check what the loudest frequency is in the 1000Hz range first + f1_range = int(bin_size / 8) + # Initially start F2 range from 1100Hz + f2_range = int(bin_size / 8 + 3) + + fft_bins = np.fft.rfft(waveData) + + f1_bins = fft_bins[:f1_range] + f1_n_loudest_bins = 5 + loudest_f1_fft_bins = np.argpartition(f1_bins, -f1_n_loudest_bins)[-f1_n_loudest_bins:] + loudest_f1_bin_values = np.take(fft_bins, loudest_f1_fft_bins) + f1_bin_sum = np.sum(loudest_f1_bin_values) + f1_weighted_avg = np.real(np.average(loudest_f1_fft_bins, weights=(loudest_f1_bin_values / f1_bin_sum))) + f1 = max(0, f1_weighted_avg) * bin_range + + # Incase the F1 is lower than 600Hz, lower the F2 range start to find low sounding vowels' F2 + if (f1 < 550): + f2_range = int(bin_size / 8 * 0.8) + + f2_bins = fft_bins[f2_range:] + f2_n_loudest_bins = 20 + loudest_f2_fft_bins = np.argpartition(f2_bins, -f2_n_loudest_bins)[-f2_n_loudest_bins:] + loudest_f2_bin_values = np.take(f2_bins, loudest_f2_fft_bins) + f2_bin_sum = np.sum(loudest_f2_bin_values) + + # Append the offset of the indexes of f2 to make sure the bins line up with the original fft bins + f2_weighted_avg = np.real(np.average([f2_bin + f2_range for f2_bin in loudest_f2_fft_bins], weights=(loudest_f2_bin_values / f2_bin_sum))) + f2 = f2_weighted_avg * bin_range + + return f1, f2 + +def determine_mfcc_type1(waveData: np.array, sampleRate: int = 16000) -> List[float]: + return mfcc( waveData, samplerate=sampleRate, nfft=1103, numcep=13, appendEnergy=True ) + +def determine_mfcc_type2(waveData: np.array, sampleRate: int = 16000) -> List[float]: + return mfcc( waveData, samplerate=sampleRate, nfft=1103, numcep=30, nfilt=40, preemph=0.5, winstep=0.005, winlen=0.015, appendEnergy=False ) + +def determine_mfsc(waveData: np.array, sampleRate:int = 16000) -> List[float]: + global _mfscs + if ( sampleRate not in _mfscs ): + _mfscs[sampleRate] = Mfsc(sr=sampleRate, n_mel=40, preem_coeff=0.5, frame_stride_ms=5, frame_size_ms=15) + _mfsc = _mfscs[sampleRate] + return _mfsc.apply( waveData ) + +# Get a feeling of how much the signal changes based on the total distance between mel frames +def determine_euclidean_dist(mfscData: np.array) -> float: + mel_frame_amount = len(mfscData) + distance = 0 + for i in range(0, mel_frame_amount): + if i > 0: + distance += np.linalg.norm(mfscData[i-1] - mfscData[i]) + return distance \ No newline at end of file diff --git a/lib/srt.py b/lib/srt.py new file mode 100644 index 00000000..eef85f8f --- /dev/null +++ b/lib/srt.py @@ -0,0 +1,328 @@ +import time +from config.config import BACKGROUND_LABEL, CURRENT_VERSION +from .typing import TransitionEvent, DetectionEvent, DetectionFrame +from typing import List +import math +import os + +def ms_to_srt_timestring( ms: int, include_hours=True): + if ms <= 0: + return "00:00:00,000" if include_hours else "00:00,000" + + if include_hours: + hours = math.floor(ms / (60 * 60 * 1000)) + ms -= hours * 60 * 60 * 1000 + minutes = math.floor(ms / (60 * 1000)) + ms -= minutes * 60 * 1000 + seconds = math.floor(ms / 1000) + ms -= 
seconds * 1000 + return ( "{:02d}".format(hours) + ":" if include_hours else "" ) + "{:02d}".format(minutes) + ":" + "{:02d}".format(seconds) + "," + "{:03d}".format(ms) + +def srt_timestring_to_ms( srt_timestring: str): + ms = int(srt_timestring.split(",")[1]) + ms += int(srt_timestring.split(":")[2].split(",")[0]) * 1000 + ms += int(srt_timestring.split(":")[1]) * 60 * 1000 + ms += int(srt_timestring.split(":")[0]) * 60 * 60 * 1000 + return ms + +def persist_srt_file(srt_filename: str, events: List[DetectionEvent]): + if not srt_filename.endswith(".v1.srt"): + srt_filename += ".v1.srt" + + # Sort events chronologically first + events.sort(key = lambda event: event.start_index) + with open(srt_filename, 'w') as srt_file: + for index, event in enumerate(events): + srt_file.write( str(index + 1) + '\n' ) + srt_file.write( ms_to_srt_timestring(event.start_ms) + " --> " + ms_to_srt_timestring(event.end_ms) + '\n' ) + srt_file.write( event.label + '\n\n' ) + +def parse_srt_file(srt_filename: str, rounding_ms: int, show_errors: bool = True) -> List[TransitionEvent]: + transition_events = [] + positive_event_list = [] + + if not srt_filename.endswith(".srt"): + srt_filename += ".srt" + + with open(srt_filename, "r") as srt: + time_start = 0 + time_end = 0 + type_sound = "" + for line_index, line in enumerate(srt): + if not line.strip(): + time_start = 0 + time_end = 0 + type_sound = "" + elif "-->" in line: + # Extract time start and end rounded to the window size + # To give the detection a fair estimate of correctness + time_pair = [timestring.strip() for timestring in line.split("-->")] + time_start = math.ceil(srt_timestring_to_ms( time_pair[0] ) / rounding_ms) * rounding_ms + + time_end = math.ceil(srt_timestring_to_ms( time_pair[1] ) / rounding_ms) * rounding_ms + elif not line.strip().isnumeric(): + if type_sound == "": + type_sound = line.strip() + if time_start < time_end: + positive_event_list.append(str(time_start) + "---" + type_sound + "---start") + positive_event_list.append(str(time_end) + "---" + type_sound + "---end") + elif show_errors: + print( ".SRT error at line " + str(line_index) + " - Start time not before end time! Not adding this event - Numbers won't be valid!" 
) + + # Sort chronologically by time + positive_event_list.sort(key = lambda event: int(event.split("---")[0])) + for time_index, time_event in enumerate(positive_event_list): + # Remove duplicates if found + if time_index != 0 and len(transition_events) > 0 and transition_events[-1].start_index == math.floor(int(time_event.split("---")[0]) / rounding_ms): + if show_errors: + print( "Found duplicate entry at second " + str(math.floor(int(time_event.split("---")[0]) / rounding_ms) / 1000) + " - Not adding duplicate") + continue; + + if time_event.endswith("---start"): + if time_index == 0 and int(time_event.split("---")[0]) > 0: + transition_events.append( TransitionEvent(BACKGROUND_LABEL, 0, 0) ) + + ms_start = math.floor(int(time_event.split("---")[0])) + + # If the time between the end and start of a new event is 0, then the previous event should be removed + if len(transition_events) > 0 and ms_start - transition_events[-1].start_ms <= rounding_ms: + transition_events.pop() + + transition_events.append( TransitionEvent(time_event.split("---")[1], math.floor(ms_start / rounding_ms), ms_start) ) + elif time_event.endswith("---end"): + ms_start = math.floor(int(time_event.split("---")[0])) + transition_events.append( TransitionEvent(BACKGROUND_LABEL, math.floor(ms_start / rounding_ms), ms_start) ) + + return transition_events + +def count_total_label_ms(label: str, base_folder: str, rounding_ms: int) -> int: + total_ms = 0 + segments_dir = os.path.join(base_folder, "segments") + if os.path.isdir(segments_dir): + srt_files = [x for x in os.listdir(segments_dir) if os.path.isfile(os.path.join(segments_dir, x)) and x.endswith(".v" + str(CURRENT_VERSION) + ".srt")] + for srt_file in srt_files: + total_ms += count_label_ms_in_srt(label, os.path.join(segments_dir, srt_file), rounding_ms) + return total_ms + +def count_label_ms_in_srt(label: str, srt_filename: str, rounding_ms: int) -> int: + transition_events = parse_srt_file(srt_filename, rounding_ms, False) + total_ms = 0 + start_ms = -1 + for transition_event in transition_events: + if transition_event.label == label: + start_ms = transition_event.start_ms + elif start_ms > -1 and transition_event.label != label: + total_ms += transition_event.start_ms - start_ms + start_ms = -1 + + return total_ms + +def print_detection_performance_compared_to_srt(actual_frames: List[DetectionFrame], frames_to_read: int, srt_file_location: str, output_wave_file = None): + ms_per_frame = actual_frames[0].duration_ms + transition_events = parse_srt_file(srt_file_location, ms_per_frame) + detection_audio_frames = [] + total_ms = 0 + + # Detection states + detected_during_index = False + false_detections = 0 + + # Times of recognitions + total_occurrences = 0 + false_recognitions = 0 + positive_recognitions = 0 + total_recognitions = 0 + + # Statistics + ms_true_negative = 0 + ms_true_positive = 0 + ms_false_negative = 0 + ms_false_positive = 0 + + false_types = { + # Types of false negative recognitions + "lag": [], + "stutter": [], + "cutoff": [], + "full_miss": [], + # Types of false positive recognitions + "late_stop": [], + "missed_dip": [], + "false_start": [], + "full_false_positive": [], + } + + # Loop over the results and compare them against the expected transition events + index = 0 + t_index = 0 + for frame in actual_frames: + index += 1 + total_ms += ms_per_frame + + # Determine expected label + actual = frame.label + expected = BACKGROUND_LABEL + transitioning = False + if t_index < len(transition_events): + if t_index + 1 < len(transition_events) 
and index >= transition_events[t_index + 1].start_index: + t_index += 1 + transitioning = True + if transition_events[t_index].label != BACKGROUND_LABEL: + total_occurrences += 1 + # If the current label is a background label, we have just passed a full occurrence + # So check if it has been found during the occurrence + else: + if detected_during_index: + positive_recognitions += 1 + else: + false_recognitions += 1 + detected_during_index = False + expected = transition_events[t_index].label + + # Add a WAVE signal for each false and true positive detections + if output_wave_file is not None: + highest_amp = 65536 / 10 + signal_strength = highest_amp if actual != BACKGROUND_LABEL else 0 + if expected != actual and actual != BACKGROUND_LABEL: + signal_strength = -highest_amp + + detection_signal = np.full(int(frames_to_read / 4), int(signal_strength)) + detection_signal[::2] = 0 + detection_signal[::3] = 0 + detection_signal[::5] = 0 + detection_signal[::7] = 0 + detection_signal[::9] = 0 + detection_audio_frames.append( detection_signal ) + + if expected == actual: + # Determine false detection types + if false_detections > 0: + false_index_start = index - false_detections + false_index_end = index + + # Determine the amount of true events that have been miscategorized + current_event_index = t_index + first_index = t_index + while( false_index_start < transition_events[first_index].start_index ): + first_index -= 1 + if first_index <= 0: + first_index = 0 + break + + for ei in range(first_index - 1, current_event_index): + event_index = ei + 1 + event = transition_events[event_index] + event_start = event.start_index + event_end = transition_events[event_index + 1].start_index if event_index + 1 < len(transition_events) else len(actual_frames) - 1 + + false_event_type = "" + ms_event = 0 + if false_index_start <= event_start: + false_index_start = event_start + + # Misrecognition of the start of an event + if false_index_end < event_end: + ms_event = (false_index_end - false_index_start ) * ms_per_frame + false_event_type = "late_stop" if event.label == BACKGROUND_LABEL else "lag" + # Misrecognition of a complete event + else: + ms_event = ( event_end - false_index_start ) * ms_per_frame + + false_event_type = "missed_dip" if event.label == BACKGROUND_LABEL else "full_miss" + elif false_index_start > event_start: + + # Misrecognition in between a full event + if false_index_end < event_end: + ms_event = ( false_index_end - false_index_start ) * ms_per_frame + false_event_type = "full_false_positive" if event.label == BACKGROUND_LABEL else "stutter" + # Misrecognition of the end of an event + else: + ms_event = (event_end - false_index_start) * ms_per_frame + false_event_type = "false_start" if event.label == BACKGROUND_LABEL else "cutoff" + + if false_event_type in false_types and ms_event > 0: + false_types[false_event_type].append( ms_event ) + + # Reset the index to the start of the next event if the event can be followed by another false event + if false_event_type in ["false_start", "cutoff", "full_miss", "full_false_positive"]: + false_index_start = event_end + false_detections = 0 + + if expected != BACKGROUND_LABEL: + if detected_during_index == False: + detected_during_index = True + ms_true_positive += ms_per_frame + else: + ms_true_negative += ms_per_frame + else: + # False detections are counted by the sum of their events + false_detections += 1 + + if output_wave_file is not None: + output_wave_file.writeframes(b''.join(detection_audio_frames)) + output_wave_file.close() + + 
# Determine total time + ms_false_positive = 0 + ms_false_negative = 0 + for false_type in false_types: + false_types[false_type] = { + "data": false_types[false_type], + } + amount = len(false_types[false_type]["data"]) + + false_types[false_type]["times"] = amount + false_types[false_type]["avg"] = round(np.mean(false_types[false_type]["data"])) if amount > 0 else 0 + false_types[false_type]["std"] = round(np.std(false_types[false_type]["data"])) if amount > 0 else 0 + if false_type in ["late_stop", "missed_dip", "false_start", "full_false_positive"]: + ms_false_positive += round(np.sum(false_types[false_type]["data"])) + else: + ms_false_negative += round(np.sum(false_types[false_type]["data"])) + + # Export the results + export_row = [] + print("-------- Detection statistics --------") + print("Expected: " + str(total_occurrences) ) + export_row.append( str(positive_recognitions) ) + export_row.append( str(false_recognitions) ) + export_row.append( "0%" if total_occurrences == 0 else str(round(positive_recognitions / total_occurrences * 100)) + "%" ) + print("Found: " + str(positive_recognitions) + " (" + ("0%" if total_occurrences == 0 else str(round(positive_recognitions / total_occurrences * 100)) + "%)") ) + print("Missed: " + str(false_recognitions) + " (" + ("0%" if total_occurrences == 0 else str(round(false_recognitions / total_occurrences * 100)) + "%)")) + print("------------- Frame data -------------") + print("Total frames: " + str(len(actual_frames))) + export_row.append( str(round((ms_true_positive + ms_true_negative) / total_ms * 1000) / 10) + "%" ) + print("Accuracy: " + export_row[-1]) + print("-------- Positive / negative --------") + export_row.append( str(round(ms_true_positive / total_ms * 1000) / 10) + "%" ) + print("True positive: " + export_row[-1]) + export_row.append( str(round(ms_true_negative / total_ms * 1000) / 10) + "%" ) + print("True negative: " + export_row[-1]) + export_row.append( str(round(ms_false_positive / total_ms * 1000) / 10) + "%" ) + print("False positive: " + export_row[-1]) + export_row.append( str(round(ms_false_negative / total_ms * 1000) / 10) + "%" ) + print("False negative: " + export_row[-1]) + print("----------- False positives ----------") + key_length = 28 + if ms_false_positive > 0: + for fp_type in [{"key": "false_start", "name": "Early start"},{"key": "missed_dip", "name": "Missed dip"},{"key": "late_stop", "name": "Late stop"},{"key": "full_false_positive", "name": "Full FP"},]: + ms_total = sum(false_types[fp_type["key"]]["data"]) + print( (fp_type["name"] + " (% of FP):").ljust(key_length, " ") + ("0%" if ms_false_positive == 0 else str(round(ms_total / ms_false_positive * 100)) + "%") + " (" + str(false_types[fp_type["key"]]["times"]) + "x)" ) + print(" [ Average " + str(false_types[fp_type["key"]]["avg"]) + "ms (σ " + str(false_types[fp_type["key"]]["std"]) + "ms) ]") + export_row.append( str(false_types[fp_type["key"]]["times"]) ) + export_row.append( str(false_types[fp_type["key"]]["avg"]) + " σ " + str(false_types[fp_type["key"]]["std"]) if false_types[fp_type["key"]]["times"] > 0 else "0" ) + else: + export_row.extend(["0", "0", "0", "0", "0", "0", "0", "0"]) + if ms_false_negative > 0: + print("----------- False negatives ----------") + for fn_type in [{"key": "lag", "name": "Lagged start"},{"key": "stutter", "name": "Stutter"},{"key": "cutoff", "name": "Early cut-off"},{"key": "full_miss", "name": "Full miss"},]: + ms_total = sum(false_types[fn_type["key"]]["data"]) + print( (fn_type["name"] + " (% of 
FN):").ljust(key_length, " ") + ("0%" if ms_false_negative == 0 else str(round(ms_total / ms_false_negative * 100)) + "%") + " (" + str(false_types[fn_type["key"]]["times"]) + "x)" ) + print(" [ Average " + str(false_types[fn_type["key"]]["avg"]) + "ms (σ " + str(false_types[fn_type["key"]]["std"]) + "ms) ]") + export_row.append( str(false_types[fn_type["key"]]["times"]) ) + export_row.append( str(false_types[fn_type["key"]]["avg"]) + " σ " + str(false_types[fn_type["key"]]["std"]) if false_types[fn_type["key"]]["times"] > 0 else "0" ) + else: + export_row.extend(["0", "0", "0", "0", "0", "0", "0", "0"]) + print("--------------------------------------") + + print("Excel row") + print( " ".join(export_row) ) \ No newline at end of file diff --git a/lib/stream_processing.py b/lib/stream_processing.py new file mode 100644 index 00000000..e6c80388 --- /dev/null +++ b/lib/stream_processing.py @@ -0,0 +1,390 @@ +from .typing import DetectionLabel, DetectionFrame, DetectionEvent, DetectionState +from config.config import BACKGROUND_LABEL, RECORD_SECONDS, SLIDING_WINDOW_AMOUNT, RATE, CURRENT_VERSION, CURRENT_DETECTION_STRATEGY +from typing import List +import wave +import math +import numpy as np +from .signal_processing import determine_power, determine_dBFS, determine_mfsc, determine_euclidean_dist +from .wav import resample_audio +from .srt import persist_srt_file, print_detection_performance_compared_to_srt +import os + +def process_wav_file(input_file, srt_file, output_file, labels, progress_callback = None, comparison_srt_file = None, print_statistics = False): + audioFrames = [] + wf = wave.open(input_file, 'rb') + number_channels = wf.getnchannels() + total_frames = wf.getnframes() + frame_rate = wf.getframerate() + frames_to_read = round( frame_rate * RECORD_SECONDS / SLIDING_WINDOW_AMOUNT ) + ms_per_frame = math.floor(RECORD_SECONDS / SLIDING_WINDOW_AMOUNT * 1000) + sample_width = 2# 16 bit = 2 bytes + + detection_strategy = CURRENT_DETECTION_STRATEGY + + detection_labels = [] + for label in labels: + detection_labels.append(DetectionLabel(label, 0, 0, "", 0, 0, 0, 0)) + detection_state = DetectionState(detection_strategy, "recording", ms_per_frame, 0, True, 0, 0, 0, detection_labels) + + false_occurrence = [] + current_occurrence = [] + index = 0 + detection_frames = [] + + if progress_callback is not None: + progress_callback(0, detection_state) + + while( wf.tell() < total_frames ): + index = index + 1 + raw_wav = wf.readframes(frames_to_read * number_channels) + detection_state.ms_recorded += ms_per_frame + detected = False + + # If our wav file is shorter than the amount of bytes ( assuming 16 bit ) times the frames, we discard it and assume we arrived at the end of the file + if (len(raw_wav) != 2 * frames_to_read * number_channels ): + break; + + # Do online downsampling if the files frame rate is higher than our 16k Hz rate + # To make sure all the calculations stay accurate + raw_wav = resample_audio(raw_wav, frame_rate, number_channels) + + audioFrames.append(raw_wav) + audioFrames, detection_state, detection_frames, current_occurrence, false_occurrence = \ + process_audio_frame(index, audioFrames, detection_state, detection_frames, current_occurrence, false_occurrence) + + # Convert from different byte sizes to 16bit for proper progress + progress = wf.tell() / total_frames + if progress_callback is not None and progress < 1: + # For the initial pass we calculate 75% of the progress + # This progress partitioning is completely arbitrary + progress_callback(progress * 0.75, 
detection_state) + + wf.close() + + output_wave_file = wave.open(output_file, 'wb') + output_wave_file.setnchannels(number_channels) + output_wave_file.setsampwidth(sample_width) + output_wave_file.setframerate(RATE) + + post_processing(detection_frames, detection_state, srt_file, progress_callback, output_wave_file, comparison_srt_file, print_statistics ) + progress = 1 + if progress_callback is not None: + progress_callback(progress, detection_state) + +def process_audio_frame(index, audioFrames, detection_state, detection_frames, current_occurrence, false_occurrence): + detection_frames.append(determine_detection_frame(index, detection_state, audioFrames)) + detected = detection_frames[-1].positive + detected_label = detection_frames[-1].label + if detected: + current_occurrence.append(detection_frames[-1]) + else: + false_occurrence.append(detection_frames[-1]) + + # Recalculate the noise floor / signal strength every 10 frames + # For performance reason and because the statistical likelyhood of things changing every 150ms is pretty low + if len(detection_frames) % 10 == 0: + detection_state = determine_detection_state(detection_frames, detection_state) + + # On-line rejection - This may be undone in post-processing later + # Only add occurrences longer than 75 ms as no sound a human produces is shorter + if detected == False and len(current_occurrence) > 0: + is_continuous = False + for label in detection_state.labels: + if label == current_occurrence[0].label: + is_continuous = label.duration_type == "continuous" + break + + if is_rejected(detection_state.strategy, current_occurrence, detection_state.ms_per_frame, is_continuous): + total_rejected_frames = len(current_occurrence) + for frame_index in range(-total_rejected_frames - 1, 0, 1): + rejected_frame_index = frame_index + detection_frames[rejected_frame_index].label = BACKGROUND_LABEL + detection_frames[rejected_frame_index].positive = False + current_occurrence = [] + # On-line mending - This may be undone in post-processing later + # Only keep false detections longer than a certain amount ( because a human can't make them shorter ) + elif detected and len(false_occurrence) > 0: + if is_mended(detection_state.strategy, false_occurrence, detection_state, detected_label): + total_mended_frames = len(false_occurrence) + for frame_index in range(-total_mended_frames - 1, 0, 1): + mended_frame_index = frame_index + detection_frames[mended_frame_index].label = detected_label + detection_frames[mended_frame_index].positive = True + false_occurrence = [] + + return audioFrames, detection_state, detection_frames, current_occurrence, false_occurrence + +def determine_detection_frame(index, detection_state, audioFrames) -> DetectionFrame: + detected = False + if( len( audioFrames ) >= SLIDING_WINDOW_AMOUNT ): + audioFrames = audioFrames[-SLIDING_WINDOW_AMOUNT:] + + byteString = b''.join(audioFrames) + wave_data = np.frombuffer( byteString, dtype=np.int16 ) + power = determine_power( wave_data ) + dBFS = determine_dBFS( wave_data ) + mfsc_data = determine_mfsc( wave_data, RATE ) + distance = determine_euclidean_dist( mfsc_data ) + + # Attempt to detect a label + detected_label = BACKGROUND_LABEL + for label in detection_state.labels: + if is_detected(detection_state.strategy, power, dBFS, distance, label.min_dBFS): + detected = True + label.ms_detected += detection_state.ms_per_frame + detected_label = label.label + break + + return DetectionFrame(index, detection_state.ms_per_frame, detected, power, dBFS, distance, mfsc_data, 
detected_label) + else: + return DetectionFrame(index, detection_state.ms_per_frame, detected, 0, 0, 0, [], BACKGROUND_LABEL) + +def post_processing(frames: List[DetectionFrame], detection_state: DetectionState, output_filename: str, progress_callback = None, output_wave_file: wave.Wave_write = None, comparison_srt_file: str = None, print_statistics = False) -> List[DetectionFrame]: + detection_state.state = "processing" + if progress_callback is not None: + progress_callback(0, detection_state) + + # Do a full pass on all the frames again to fix labels we might have missed + if "repair" in detection_state.strategy: + current_occurrence = [] + false_occurrence = [] + current_label = None + detected_label = None + + # Recalculate the MS detection and duration type + for label in detection_state.labels: + label.ms_detected = 0 + label.duration_type = determine_duration_type(label, frames) + + for index, frame in enumerate(frames): + detected = False + for label in detection_state.labels: + if is_detected(detection_state.strategy, frame.power, frame.dBFS, frame.euclid_dist, label.min_dBFS): + detected = True + label.ms_detected += detection_state.ms_per_frame + current_label = label + break + + # Do a secondary pass if the previous label was negative + # As we can use its thresholds for correcting late starts + mending_offset = 0 + if detected and not frames[index - 1].positive: + for label in detection_state.labels: + if current_label.label == label.label and is_detected_secondary(detection_state.strategy, frames[index - 1].power, frames[index - 1].dBFS, frames[index - 1].euclid_dist, label.min_dBFS - 4): + label.ms_detected += detection_state.ms_per_frame + frames[index - 1].label = current_label.label + frames[index - 1].positive = True + mending_offset = -1 + if len(false_occurrence) > 0: + false_occurrence.pop() + + # Only do two frames of late start fixing as longer late starts statistically do not seem to occur + if not frames[index - 2].positive and is_detected_secondary(detection_state.strategy, frames[index - 2].power, frames[index - 2].dBFS, frames[index - 2].euclid_dist, label.min_dBFS - 4): + label.ms_detected += detection_state.ms_per_frame + frames[index - 2].label = current_label.label + frames[index - 2].positive = True + mending_offset = -2 + if len(false_occurrence) > 0: + false_occurrence.pop() + break + + if detected: + current_occurrence.append(frame) + frame.label = current_label.label + frame.positive = True + frames[index] = frame + + if len(false_occurrence) > 0: + if is_mended(detection_state.strategy, false_occurrence, detection_state, current_label.label): + total_mended_frames = len(false_occurrence) + current_label.ms_detected += total_mended_frames * detection_state.ms_per_frame + for frame_index in range(-total_mended_frames - 1 + mending_offset, mending_offset, 1): + mended_frame_index = index + frame_index + frames[mended_frame_index].label = current_label.label + frames[mended_frame_index].positive = True + false_occurrence = [] + + if not detected: + false_occurrence.append(frame) + frame.positive = False + frame.label = BACKGROUND_LABEL + frames[index] = frame + + if len(current_occurrence) > 0: + is_continuous = False + for label in detection_state.labels: + if label == current_occurrence[0].label: + is_continuous = label.duration_type == "continuous" + break + + if is_rejected(detection_state.strategy, current_occurrence, detection_state.ms_per_frame, is_continuous): + total_rejected_frames = len(current_occurrence) + current_label.ms_detected -= 
total_rejected_frames * detection_state.ms_per_frame + current_label = None + for frame_index in range(-total_rejected_frames - 1, 0, 1): + rejected_frame_index = index + frame_index + frames[rejected_frame_index].label = BACKGROUND_LABEL + frames[rejected_frame_index].positive = False + current_occurrence = [] + + progress = index / len(frames) + if progress_callback is not None and progress < 1: + # For the post processing phase - we count the remaining 25% of the progress + # This progress partitioning is completely arbitrary + progress_callback(0.75 + ( progress * 0.25 ), detection_state) + + + # Persist the SRT file + events = detection_frames_to_events(frames) + persist_srt_file( output_filename, events ) + + comparisonOutputWaveFile = None + if print_statistics: + if output_wave_file is not None: + comparisonOutputWaveFile = wave.open(output_filename + "_comparison.wav", 'wb') + comparisonOutputWaveFile.setnchannels(output_wave_file.getnchannels()) + comparisonOutputWaveFile.setsampwidth(output_wave_file.getsampwidth()) + comparisonOutputWaveFile.setframerate(output_wave_file.getframerate()) + + print_detection_performance_compared_to_srt(frames, detection_state.ms_per_frame, comparison_srt_file, comparisonOutputWaveFile) + + # Persist the detection wave file + if output_wave_file is not None: + frames_to_write = round( RATE * RECORD_SECONDS / SLIDING_WINDOW_AMOUNT ) + sample_width = 2# 16 bit = 2 bytes + detection_audio_frames = [] + for frame in frames: + highest_amp = 65536 / 10 + signal_strength = highest_amp if frame.positive else 0 + + detection_signal = np.full(int(frames_to_write / sample_width), int(signal_strength)) + detection_signal[::2] = 0 + detection_signal[::3] = 0 + detection_signal[::5] = 0 + detection_signal[::7] = 0 + detection_signal[::9] = 0 + detection_audio_frames.append( detection_signal ) + output_wave_file.writeframes(b''.join(detection_audio_frames)) + output_wave_file.close() + + detection_state.state = "recording" + return frames + +def determine_detection_state(detection_frames: List[DetectionFrame], detection_state: DetectionState) -> DetectionState: + # Filter out very low power dbFS values as we can assume the hardware microphone is off + # And we do not want to skew the mean for that as it would create more false positives + # ( -70 dbFS was selected as a cut off after a bit of testing with a HyperX Quadcast microphone ) + dBFS_frames = [x.dBFS for x in detection_frames if x.dBFS > -70] + std_dbFS = np.std(dBFS_frames) + + minimum_dBFS = np.min(dBFS_frames) + + # For noisy signals and for clean signals we need different noise floor and threshold estimation + # Because noisy thresholds have a lower standard deviation across the signal + # Whereas clean signals have a very clear floor and do not need as high of a threshold + noisy_threshold = False + detection_state.expected_snr = math.floor(std_dbFS * 2) + if detection_state.expected_snr < 25: + noisy_threshold = True + detection_state.expected_noise_floor = minimum_dBFS + std_dbFS + else: + detection_state.expected_noise_floor = minimum_dBFS + + for label in detection_state.labels: + + # Recalculate the duration type every 15 seconds + if label.duration_type == "" or len(detection_frames) % round(15 / RECORD_SECONDS): + label.duration_type = determine_duration_type(label, detection_frames) + label.min_dBFS = detection_state.expected_noise_floor + ( detection_state.expected_snr if noisy_threshold else detection_state.expected_snr / 2 ) + detection_state.latest_dBFS = detection_frames[-1].dBFS + 
return detection_state + +# Approximately determine whether the label in the stream is discrete or continuous +# Discrete sounds are from a single source event like a click, tap or a snap +# Whereas continuous sounds have a steady stream of energy from a source +def determine_duration_type(label: DetectionLabel, detection_frames: List[DetectionFrame]) -> str: + label_events = [x for x in detection_frames_to_events(detection_frames) if x.label == label.label] + if len(label_events) < 4: + return "" + else: + # The assumption here is that discrete sounds cannot vary in length much as you cannot elongate the sound of a click for example + # So if the length doesn't vary much, we assume discrete over continuous + lengths = [x.end_ms - x.start_ms for x in label_events] + continuous_length_threshold = 35 + return "discrete" if np.std(lengths) < continuous_length_threshold else "continuous" + +def detection_frames_to_events(detection_frames: List[DetectionFrame]) -> List[DetectionEvent]: + events = [] + current_label = "" + current_frames = [] + for frame in detection_frames: + if frame.label != current_label: + if len(current_frames) > 0: + events.append( DetectionEvent(current_label, current_frames[0].index, current_frames[-1].index, \ + (current_frames[0].index - 1) * current_frames[0].duration_ms, (current_frames[-1].index) * current_frames[-1].duration_ms, current_frames) ) + current_frames = [] + current_label = frame.label + + if current_label != BACKGROUND_LABEL: + current_frames.append( frame ) + + if len(current_frames) > 0: + events.append( DetectionEvent(current_label, current_frames[0].index, current_frames[-1].index, \ + (current_frames[0].index - 1) * current_frames[0].duration_ms, (current_frames[-1].index) * current_frames[-1].duration_ms, current_frames) ) + current_frames = [] + return events + +def auto_decibel_detection(power, dBFS, distance, dBFS_threshold): + return dBFS > dBFS_threshold + +def auto_secondary_decibel_detection(power, dBFS, distance, dBFS_threshold): + return dBFS > dBFS_threshold - 7 + +def is_detected(strategy, power, dBFS, distance, estimated_threshold): + if "auto_dBFS" in strategy: + return auto_decibel_detection(power, dBFS, distance, estimated_threshold) + +def is_rejected( strategy, occurrence, ms_per_frame, continuous = False ): + if "reject" not in strategy: + return False + elif "reject_45ms" in strategy: + return len(occurrence) * ms_per_frame < 45 + elif "reject_60ms" in strategy: + return len(occurrence) * ms_per_frame < 60 + elif "reject_75ms" in strategy: + return len(occurrence) * ms_per_frame < 75 + elif "reject_90ms" in strategy: + return len(occurrence) * ms_per_frame < 90 + elif "reject_cont_45ms" in strategy: + return len(occurrence) * ms_per_frame < ( 45 if continuous else 0 ) + +def is_detected_secondary( strategy, power, dBFS, distance, estimated_threshold ): + if "secondary" not in strategy: + return False + elif "secondary_dBFS" in strategy: + return auto_secondary_decibel_detection(power, dBFS, distance, estimated_threshold) + +def is_mended( strategy, occurrence, detection_state, current_label ): + if "mend" not in strategy: + return False + elif "mend_60ms" in strategy: + return len(occurrence) * detection_state.ms_per_frame < 60 + elif "mend_45ms" in strategy: + return len(occurrence) * detection_state.ms_per_frame < 45 + elif "mend_dBFS" in strategy: + label_dBFS_threshold = 0 + for label in detection_state.labels: + if label.label == current_label: + label_dBFS_threshold = label.min_dBFS + + total_missed_length_ms = 0 + for 
frame in occurrence: + if not auto_secondary_decibel_detection(frame.power, frame.dBFS, frame.euclid_dist, label_dBFS_threshold): + if not "mend_dBFS_30ms" in strategy: + return False + else: + total_missed_length_ms += detection_state.ms_per_frame + if not "mend_dBFS_30ms" in strategy: + return True + else: + return total_missed_length_ms < 30 diff --git a/lib/stream_recorder.py b/lib/stream_recorder.py new file mode 100644 index 00000000..8de924d9 --- /dev/null +++ b/lib/stream_recorder.py @@ -0,0 +1,195 @@ +from config.config import * +import pyaudio +import struct +import wave +import math +import numpy as np +from lib.print_status import get_current_status +from lib.stream_processing import process_audio_frame, post_processing +from lib.typing import DetectionState, DetectionFrame +from typing import List +import io + +class StreamRecorder: + total_wav_filename: str + srt_filename: str + comparison_wav_filename: str + + audio: pyaudio.PyAudio + stream: pyaudio.Stream + detection_state: DetectionState + + length_per_frame: int + audio_frames: List[np.array] + total_audio_frames: List[np.array] + index: int + detection_frames: List[DetectionFrame] + current_occurrence: List[DetectionFrame] + false_occurrence: List[DetectionFrame] + + def __init__(self, audio: pyaudio.PyAudio, stream: pyaudio.Stream, total_wav_filename: str, srt_filename: str, detection_state: DetectionState): + self.total_wav_filename = total_wav_filename + self.srt_filename = srt_filename + self.comparison_wav_filename = srt_filename.replace(".v" + str(CURRENT_VERSION) + ".srt", "_comparison.wav") + + self.audio = audio + self.stream = stream + self.detection_state = detection_state + self.total_audio_frames = [] + self.audio_frames = [] + self.detection_frames = [] + self.current_occurrence = [] + self.false_occurrence = [] + self.index = 0 + self.length_per_frame = 0 + + # Write the source file first with the right settings to add the headers, and write the data later + totalWaveFile = wave.open(self.total_wav_filename, 'wb') + totalWaveFile.setnchannels(CHANNELS) + totalWaveFile.setsampwidth(audio.get_sample_size(FORMAT)) + totalWaveFile.setframerate(RATE) + totalWaveFile.close() + + # Add a single audio frame to the batch and start processing it + def add_audio_frame(self, frame: List[np.array]): + if self.length_per_frame == 0: + self.length_per_frame = len(frame) + + self.index += 1 + self.audio_frames.append(frame) + self.detection_state.ms_recorded += self.detection_state.ms_per_frame + audioFrames, detection_state, detection_frames, current_occurrence, false_occurrence = \ + process_audio_frame(self.index, self.audio_frames, self.detection_state, self.detection_frames, self.current_occurrence, self.false_occurrence) + + self.current_occurence = current_occurrence + self.false_occurrence = false_occurrence + self.detection_state = detection_state + self.detection_frames = detection_frames + self.audio_frames = audioFrames + self.total_audio_frames.append( audioFrames[-1] ) + + # Append to the total wav file only once every fifteen audio frames + # This is roughly once every 225 milliseconds + if len(self.total_audio_frames) >= 15: + self.persist_total_wav_file() + + def persist_total_wav_file(self): + # This is used to modify the wave file directly + CHUNK_SIZE_OFFSET = 4 + DATA_SUB_CHUNK_SIZE_SIZE_OFFSET = 40 + LITTLE_ENDIAN_INT = struct.Struct(' bool: + should_resume = self.detection_state != "paused" + self.pause() + + ms_per_frame = self.detection_state.ms_per_frame + frames_to_remove = math.floor(seconds * 
1000 / ms_per_frame) + clear_file = False + if (self.index < frames_to_remove): + clear_file = True + self.index -= self.index if clear_file else frames_to_remove + self.current_occurrence = [] + self.false_occurrence = [] + + self.detection_frames = self.detection_frames[:-frames_to_remove] + self.detection_state.ms_recorded = len(self.detection_frames) * ms_per_frame + for label in self.detection_state.labels: + label.ms_detected = 0 + for frame in self.detection_frames: + if frame.label == label.label: + label.ms_detected += ms_per_frame + + # Just completely overwrite the file if we go back to the start for simplicities sake + if clear_file: + totalWaveFile = wave.open(self.total_wav_filename, 'wb') + totalWaveFile.setnchannels(CHANNELS) + totalWaveFile.setsampwidth(self.audio.get_sample_size(FORMAT)) + totalWaveFile.setframerate(RATE) + totalWaveFile.close() + + # Truncate the frames from the total wav file + else: + with open(self.total_wav_filename, 'r+b') as f: + # Drop the last N bytes from the file + f.seek(-frames_to_remove * self.length_per_frame, io.SEEK_END) + f.truncate() + + # Overwrite the total recording length + CHUNK_SIZE_OFFSET = 4 + DATA_SUB_CHUNK_SIZE_SIZE_OFFSET = 40 + LITTLE_ENDIAN_INT = struct.Struct(' DetectionState: + return self.detection_state + + def get_status(self, detection_states: List[DetectionState] = []) -> List[str]: + return get_current_status(self.detection_state, detection_states) + + # Stop processing the streams and build the final files + def stop(self, callback = None): + self.pause() + self.persist_total_wav_file() + if self.index == 0: + os.remove(self.total_wav_filename) + + comparison_wav_file = wave.open(self.comparison_wav_filename, 'wb') + comparison_wav_file.setnchannels(1) + comparison_wav_file.setsampwidth(2) + comparison_wav_file.setframerate(RATE) + post_processing(self.detection_frames, self.detection_state, self.srt_filename, callback, comparison_wav_file) + self.stream.close() + self.audio.terminate() + self.detection_frames = [] + + # Do all post processing related tasks that cannot be done during runtime + def post_processing(self, callback = None, comparison_wav_file: wave.Wave_write = None): + self.persist_total_wav_file() + post_processing(self.detection_frames, self.detection_state, self.srt_filename, callback, comparison_wav_file) diff --git a/lib/typing.py b/lib/typing.py new file mode 100644 index 00000000..5d485156 --- /dev/null +++ b/lib/typing.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import List + +@dataclass +class TransitionEvent: + label: str + start_index: int + start_ms: int + +@dataclass +class DetectionFrame: + index: int + duration_ms: int + positive: bool + power: float + dBFS: float + euclid_dist: float + mel_data: List[List[float]] + label: str + +@dataclass +class DetectionEvent: + label: str + + # Based on wave indecis + start_index: int + end_index: int + start_ms: int + end_ms: int + frames: List[DetectionFrame] + +@dataclass +class DetectionLabel: + label: str + ms_detected: int + previous_detected: int + duration_type: str + + min_ms: float + min_dBFS: float + min_distance: float + max_distance: float + +@dataclass +class DetectionState: + strategy: str + state: str + ms_per_frame: int + ms_recorded: int + advanced_logging: bool + + latest_dBFS: float + expected_snr: float + expected_noise_floor: float + labels: List[DetectionLabel] \ No newline at end of file diff --git a/lib/wav.py b/lib/wav.py new file mode 100644 index 00000000..4b532179 --- /dev/null +++ b/lib/wav.py @@ 
-0,0 +1,127 @@ +import wave +from config.config import BACKGROUND_LABEL, RECORD_SECONDS, SLIDING_WINDOW_AMOUNT, RATE, TYPE_FEATURE_ENGINEERING_NORM_MFSC, PYTORCH_AVAILABLE +from lib.machinelearning import feature_engineering_raw +from .srt import parse_srt_file +import numpy as np +import audioop +from typing import List +import os +import time +import math +if (PYTORCH_AVAILABLE == True): + from audiomentations import Compose, AddGaussianNoise, Shift, TimeStretch + +# Resamples the audio down to 16kHz ( or any other RATE filled in ) +# To make sure all the other calculations are stable and correct +def resample_audio(wavData: np.array, frame_rate, number_channels) -> np.array: + if frame_rate > RATE: + sample_width = 2# 16 bit = 2 bytes + wavData, _ = audioop.ratecv(wavData, sample_width, number_channels, frame_rate, RATE, None) + if number_channels > 1: + wavData = audioop.tomono(wavData[0], 2, 1, 0) + return wavData + +def load_wav_files_with_srts( directories, label, int_label, start, end, input_type ): + category_dataset_x = [] + category_dataset_labels = [] + totalFeatureEngineeringTime = 0 + category_file_index = 0 + + for directory in directories: + source_directory = os.path.join( directory, "source" ) + segments_directory = os.path.join( directory, "segments" ) + + srt_files = [] + + for fileindex, file in enumerate(os.listdir(segments_directory)): + if file.endswith(".srt"): + srt_files.append(file) + + for source_index, source_file in enumerate(os.listdir(source_directory)): + if source_file.endswith(".wav"): + full_filename = os.path.join(source_directory, source_file) + print( "Loading " + str(category_file_index) + " files for " + label + "... ", end="\r" ) + category_file_index += 1 + + # Find the SRT files available for this source file + shared_key = source_file.replace(".wav", "") + possible_srt_files = [x for x in srt_files if x.startswith(shared_key)] + if len(possible_srt_files) == 0: + continue + + # Find the highest version of the segmentation for this source file + srt_file = possible_srt_files[0] + for possible_srt_file in possible_srt_files: + current_version = int( srt_file.replace(".srt", "").replace(shared_key + ".v", "") ) + version = int( possible_srt_file.replace(".srt", "").replace(shared_key + ".v", "") ) + if version > current_version: + srt_file = possible_srt_file + full_srt_filename = os.path.join(segments_directory, srt_file) + + # Load the WAV file and turn it into a onedimensional array of numbers + feature_engineering_start = time.time() * 1000 + data = load_wav_data_from_srt(full_srt_filename, full_filename, input_type, False) + category_dataset_x.extend( data ) + category_dataset_labels.extend([ label for data_row in data ]) + totalFeatureEngineeringTime += time.time() * 1000 - feature_engineering_start + + print( "Loaded " + str( len( category_dataset_labels ) ) + " .wav files for category " + label + " (id: " + str(int_label) + ")" ) + return category_dataset_x, category_dataset_labels, totalFeatureEngineeringTime + +def augment_wav_data(wavData, sample_rate): + augmenter = Compose([ + AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5), + TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5), + Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5), + ]) + return augmenter(samples=np.array(wavData, dtype="float32"), sample_rate=sample_rate) + +def load_wav_data_from_srt(srt_file: str, source_file: str, feature_engineering_type = TYPE_FEATURE_ENGINEERING_NORM_MFSC, with_offset = True, should_augment=False) -> List[List[float]]: + 
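+    # Summary of the loop below: walk the transition events parsed from the SRT file,
+    # seek the source wav to the start of every non-background event ( and, when with_offset
+    # is True, to half a frame before it to harvest extra windows ), keep reading frames until
+    # the next event starts, resample them to RATE, and turn every filled sliding window
+    # into a single feature row using feature_engineering_raw.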
+    wav_file_data = []
+    wf = wave.open(source_file, 'rb')
+    frame_rate = wf.getframerate()
+    number_channels = wf.getnchannels()
+    total_frames = wf.getnframes()
+    frames_to_read = round( frame_rate * RECORD_SECONDS / SLIDING_WINDOW_AMOUNT )
+    ms_per_frame = math.floor(RECORD_SECONDS / SLIDING_WINDOW_AMOUNT * 1000)
+
+    # If offsets are required - We seek half a frame behind the expected frame to get more data from a different location
+    halfframe_offset = round( frames_to_read * number_channels * 0.5 )
+    start_offsets = [0, -halfframe_offset] if with_offset else [0]
+
+    transition_events = parse_srt_file(srt_file, ms_per_frame)
+    for index, transition_event in enumerate(transition_events):
+        next_event_index = total_frames / frames_to_read if index + 1 >= len(transition_events) else transition_events[index + 1].start_index
+        audioFrames = []
+
+        if transition_event.label != BACKGROUND_LABEL:
+            for offset in start_offsets:
+                # Skip if the offset makes the position fall before the start of the file
+                if offset + (frames_to_read * transition_event.start_index) < 0:
+                    continue
+                wf.setpos(offset + (frames_to_read * transition_event.start_index))
+
+                keep_collecting = True
+                while keep_collecting:
+                    raw_wav = wf.readframes(frames_to_read * number_channels)
+
+                    # Reached the end of the wav - do not keep collecting
+                    if (len(raw_wav) != SLIDING_WINDOW_AMOUNT * frames_to_read * number_channels ):
+                        keep_collecting = False
+                        break
+
+                    raw_wav = resample_audio(raw_wav, frame_rate, number_channels)
+                    audioFrames.append(raw_wav)
+                    if( len( audioFrames ) >= SLIDING_WINDOW_AMOUNT ):
+                        audioFrames = audioFrames[-SLIDING_WINDOW_AMOUNT:]
+
+                        byteString = b''.join(audioFrames)
+                        wave_data = np.frombuffer( byteString, dtype=np.int16 )
+                        if should_augment and PYTORCH_AVAILABLE:
+                            wave_data = augment_wav_data(wave_data, RATE)
+                        wav_file_data.append( feature_engineering_raw(wave_data, RATE, 0, RECORD_SECONDS, feature_engineering_type)[0] )
+
+                    if wf.tell() >= ( next_event_index * frames_to_read ) + offset:
+                        keep_collecting = False
+
+    return wav_file_data
\ No newline at end of file
diff --git a/settings.py b/settings.py
index 4245dcd3..4db104bf 100644
--- a/settings.py
+++ b/settings.py
@@ -4,6 +4,7 @@
 from lib.test_data import test_data
 from lib.convert_files import convert_files
 from lib.combine_models import combine_models
+from lib.migrate_data import check_migration
 
 def root_navigation( first):
     if( first ):
@@ -41,5 +42,6 @@ def select_mode():
         root_navigation( False )
     elif( setup_mode.lower() == 'x' ):
         print( "Goodbye." )
-
+
+check_migration()
 root_navigation( True )
\ No newline at end of file
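
For reference, a minimal sketch of how the segmentation and loading functions introduced in this patch might be wired together. The recording paths and the "bell" label are hypothetical examples; only the function signatures come from lib/stream_processing.py and lib/wav.py.

    from lib.stream_processing import process_wav_file
    from lib.wav import load_wav_data_from_srt

    # Hypothetical recording layout: data/recordings/bell/{source,segments}
    source_wav = "data/recordings/bell/source/bell.wav"
    srt_file = "data/recordings/bell/segments/bell.v1.srt"
    detection_wav = "data/recordings/bell/segments/bell_detection.wav"

    # Segment the source recording: writes the SRT file and a wav containing the detection signal
    process_wav_file(source_wav, srt_file, detection_wav, ["bell"],
        progress_callback=lambda progress, state: print(str(round(progress * 100)) + "%", end="\r"))

    # Load the segmented windows back as feature rows for training
    samples = load_wav_data_from_srt(srt_file, source_wav)
    print("Loaded " + str(len(samples)) + " samples")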