From f217aafcc91830add7c0d6ff02aaef4ef201fcc7 Mon Sep 17 00:00:00 2001 From: gudgud96 Date: Tue, 14 May 2024 17:20:21 +0800 Subject: [PATCH] [feature] add num_threads and worker count for madmom beat extraction --- omnizart/drum/app.py | 8 +++-- omnizart/feature/beat_for_drum.py | 58 +++++++++++++++++++------------ omnizart/feature/wrapper_func.py | 9 +++-- 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/omnizart/drum/app.py b/omnizart/drum/app.py index aed3da3..f7dda75 100644 --- a/omnizart/drum/app.py +++ b/omnizart/drum/app.py @@ -32,7 +32,7 @@ def __init__(self): super().__init__(DrumSettings) self.custom_objects = {"ConvSN2D": ConvSN2D} - def transcribe(self, input_audio, model_path=None, output="./"): + def transcribe(self, input_audio, model_path=None, output="./", beat_tracker_num_threads=3, beat_tracker_parallel_workers=3): """Transcribe drum in the audio. This function transcribes drum activations in the music. Currently the model @@ -62,7 +62,11 @@ def transcribe(self, input_audio, model_path=None, output="./"): # Extract feature according to model configuration logger.info("Extracting feature...") - patch_cqt_feature, mini_beat_arr = extract_patch_cqt(input_audio) + patch_cqt_feature, mini_beat_arr = extract_patch_cqt( + input_audio, + num_threads=beat_tracker_num_threads, + num_workers=beat_tracker_parallel_workers + ) # Load model configurations logger.info("Loading model...") diff --git a/omnizart/feature/beat_for_drum.py b/omnizart/feature/beat_for_drum.py index 978a1dc..4321b99 100644 --- a/omnizart/feature/beat_for_drum.py +++ b/omnizart/feature/beat_for_drum.py @@ -25,8 +25,9 @@ class MadmomBeatTracking: Three different beat tracking methods are used together for producing a more stable beat tracking result. """ - def __init__(self, num_threads=3): + def __init__(self, num_threads=3, parallel_workers=3): self.num_threads = num_threads + self.parallel_workers=parallel_workers def _get_dbn_down_beat(self, audio_data_in1, min_bpm_in=50, max_bpm_in=230): proccesor = DBNDownBeatTrackingProcessor( @@ -51,23 +52,30 @@ def _get_beat(self, audio_data_in3): def process(self, audio_data): """Generate beat tracking results with multiple approaches.""" - with ProcessPoolExecutor(max_workers=3) as executor: - logger.debug("Submitting and executing parallel beat tracking jobs") - future_1 = executor.submit(self._get_dbn_down_beat, audio_data, min_bpm_in=50, max_bpm_in=230) - future_2 = executor.submit(self._get_dbn_beat, audio_data) - future_3 = executor.submit(self._get_beat, audio_data) - - queue = {future_1: "dbn_down_beat", future_2: "dbn_beat", future_3: "beat"} - - results = {} - for future in concurrent.futures.as_completed(queue, timeout=600): - func_name = queue[future] - results[func_name] = future.result() - logger.debug("Job %s finished.", func_name) - - pred_beats1 = results["dbn_down_beat"] - pred_beats2 = results["dbn_beat"] - pred_beats3 = results["beat"] + if self.parallel_workers == 0: + # Run sequentially + logger.debug("Running beat tracking sequentially...") + pred_beats1 = self._get_dbn_down_beat(audio_data, min_bpm_in=50, max_bpm_in=230) + pred_beats2 = self._get_dbn_beat(audio_data) + pred_beats3 = self._get_beat(audio_data) + else: + with ProcessPoolExecutor(max_workers=self.parallel_workers) as executor: + logger.debug("Submitting and executing parallel beat tracking jobs") + future_1 = executor.submit(self._get_dbn_down_beat, audio_data, min_bpm_in=50, max_bpm_in=230) + future_2 = executor.submit(self._get_dbn_beat, audio_data) + future_3 = executor.submit(self._get_beat, audio_data) + + queue = {future_1: "dbn_down_beat", future_2: "dbn_beat", future_3: "beat"} + + results = {} + for future in concurrent.futures.as_completed(queue, timeout=600): + func_name = queue[future] + results[func_name] = future.result() + logger.debug("Job %s finished.", func_name) + + pred_beats1 = results["dbn_down_beat"] + pred_beats2 = results["dbn_beat"] + pred_beats3 = results["beat"] pred_beat_len1 = np.mean( np.sort(pred_beats1[1:] - pred_beats1[:-1])[int(len(pred_beats1) * 0.2):int(len(pred_beats1) * 0.8)] @@ -89,7 +97,7 @@ def process(self, audio_data): return self._get_dbn_down_beat(audio_data, min_bpm_in=pred_bpm_avg / 1.38, max_bpm_in=pred_bpm_avg * 1.38) -def extract_beat_with_madmom(audio_path, sampling_rate=44100): +def extract_beat_with_madmom(audio_path, sampling_rate=44100, parallel_workers=3, num_threads=3): """Extract beat position (in seconds) of the audio. Extract beat with mixture of beat tracking techiniques using madmom. @@ -111,7 +119,8 @@ def extract_beat_with_madmom(audio_path, sampling_rate=44100): logger.debug("Loading audio: %s", audio_path) audio_data, _ = load_audio(audio_path, sampling_rate=sampling_rate) logger.debug("Runnig beat tracking...") - return MadmomBeatTracking().process(audio_data), len(audio_data) / sampling_rate + mbt = MadmomBeatTracking(num_threads=num_threads, parallel_workers=parallel_workers) + return mbt.process(audio_data), len(audio_data) / sampling_rate def extract_mini_beat_from_beat_arr(beat_arr, audio_len_sec, mini_beat_div_n=32): @@ -152,10 +161,15 @@ def extract_mini_beat_from_beat_arr(beat_arr, audio_len_sec, mini_beat_div_n=32) return mini_beat_pos_t -def extract_mini_beat_from_audio_path(audio_path, sampling_rate=44100, mini_beat_div_n=32): +def extract_mini_beat_from_audio_path(audio_path, sampling_rate=44100, mini_beat_div_n=32, parallel_workers=3, num_threads=3): """ Wrapper of extracting mini beats from audio path. """ logger.debug("Extracting beat with madmom") - beat_arr, audio_len_sec = extract_beat_with_madmom(audio_path, sampling_rate=sampling_rate) + beat_arr, audio_len_sec = extract_beat_with_madmom( + audio_path, + sampling_rate=sampling_rate, + parallel_workers=parallel_workers, + num_threads=num_threads + ) logger.debug("Extracting mini beat") return extract_mini_beat_from_beat_arr(beat_arr, audio_len_sec, mini_beat_div_n=mini_beat_div_n) diff --git a/omnizart/feature/wrapper_func.py b/omnizart/feature/wrapper_func.py index 94a9793..98f9e3f 100644 --- a/omnizart/feature/wrapper_func.py +++ b/omnizart/feature/wrapper_func.py @@ -28,7 +28,7 @@ def get_frame_by_time(time_sec, sampling_rate=44100, hop_size=256): return int(round(time_sec * sampling_rate / hop_size)) -def extract_patch_cqt(audio_path, sampling_rate=44100, hop_size=256): +def extract_patch_cqt(audio_path, sampling_rate=44100, hop_size=256, beat_tracker_num_threads=3, beat_tracker_parallel_workers=3): """Extract patched CQT feature. Leverages mini-beat information to determine the bound of each @@ -51,7 +51,12 @@ def extract_patch_cqt(audio_path, sampling_rate=44100, hop_size=256): omnizart.feature.beat_for_drum.extract_mini_beat_from_audio_path: Function for extracting mini-beat. """ cqt_ext = cqt.extract_cqt(audio_path, sampling_rate=sampling_rate, a_hop=hop_size) - mini_beat_arr = b4d.extract_mini_beat_from_audio_path(audio_path, sampling_rate=sampling_rate) + mini_beat_arr = b4d.extract_mini_beat_from_audio_path( + audio_path, + sampling_rate=sampling_rate, + num_threads=beat_tracker_num_threads, + parallel_workers=beat_tracker_parallel_workers + ) m_beat_cqt_patch_list = [] for m_beat_t_cur in mini_beat_arr: