From 14bf689c983c68c7e870b970e1bcce79c791c3c9 Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Sun, 8 Dec 2024 14:09:42 +0200
Subject: [PATCH] Fix pyannote processor `post_process_speaker_diarization` (#1082)

---
 .../pyannote/feature_extraction_pyannote.js | 56 +++++++++++++++++
 src/models/pyannote/processing_pyannote.js  | 61 +++----------------
 2 files changed, 63 insertions(+), 54 deletions(-)

diff --git a/src/models/pyannote/feature_extraction_pyannote.js b/src/models/pyannote/feature_extraction_pyannote.js
index 74b40fec9..a0044e159 100644
--- a/src/models/pyannote/feature_extraction_pyannote.js
+++ b/src/models/pyannote/feature_extraction_pyannote.js
@@ -1,5 +1,6 @@
 import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js';
 import { Tensor } from '../../utils/tensor.js';
+import { max, softmax } from '../../utils/maths.js';
 
 
 export class PyAnnoteFeatureExtractor extends FeatureExtractor {
@@ -25,4 +26,59 @@ export class PyAnnoteFeatureExtractor extends FeatureExtractor {
         };
     }
 
+    /**
+     * NOTE: Can return fractional values. `Math.ceil` will ensure correct value.
+     * @param {number} samples The number of samples in the audio.
+     * @returns {number} The number of frames in the audio.
+     */
+    samples_to_frames(samples) {
+        return ((samples - this.config.offset) / this.config.step);
+    }
+
+    /**
+     * Post-processes the speaker diarization logits output by the model.
+     * @param {import('../../utils/tensor.js').Tensor} logits The speaker diarization logits output by the model.
+     * @param {number} num_samples Number of samples in the input audio.
+     * @returns {Array<Array<Object>>} The post-processed speaker diarization results.
+     */
+    post_process_speaker_diarization(logits, num_samples) {
+        const ratio = (
+            num_samples / this.samples_to_frames(num_samples)
+        ) / this.config.sampling_rate;
+
+        const results = [];
+        for (const scores of logits.tolist()) {
+            const accumulated_segments = [];
+
+            let current_speaker = -1;
+            for (let i = 0; i < scores.length; ++i) {
+                const probabilities = softmax(scores[i]);
+                const [score, id] = max(probabilities);
+                const [start, end] = [i, i + 1];
+
+                if (id !== current_speaker) {
+                    // Speaker has changed
+                    current_speaker = id;
+                    accumulated_segments.push({ id, start, end, score });
+                } else {
+                    // Continue the current segment
+                    accumulated_segments.at(-1).end = end;
+                    accumulated_segments.at(-1).score += score;
+                }
+            }
+
+            results.push(accumulated_segments.map(
+                // Convert frame-space to time-space
+                // and compute the confidence
+                ({ id, start, end, score }) => ({
+                    id,
+                    start: start * ratio,
+                    end: end * ratio,
+                    confidence: score / (end - start),
+                })
+            ));
+        }
+        return results;
+    }
+
 }
diff --git a/src/models/pyannote/processing_pyannote.js b/src/models/pyannote/processing_pyannote.js
index cf66251a8..e5fff9cb6 100644
--- a/src/models/pyannote/processing_pyannote.js
+++ b/src/models/pyannote/processing_pyannote.js
@@ -1,9 +1,8 @@
 import { Processor } from '../../base/processing_utils.js';
-import { AutoFeatureExtractor } from '../auto/feature_extraction_auto.js';
-import { max, softmax } from '../../utils/maths.js';
+import { PyAnnoteFeatureExtractor } from './feature_extraction_pyannote.js';
 
 export class PyAnnoteProcessor extends Processor {
-    static feature_extractor_class = AutoFeatureExtractor
+    static feature_extractor_class = PyAnnoteFeatureExtractor
 
     /**
      * Calls the feature_extractor function with the given audio input.
@@ -14,58 +13,12 @@ export class PyAnnoteProcessor extends Processor {
         return await this.feature_extractor(audio)
     }
 
-    /**
-     * NOTE: Can return fractional values. `Math.ceil` will ensure correct value.
-     * @param {number} samples The number of frames in the audio.
-     * @returns {number} The number of frames in the audio.
-     */
-    samples_to_frames(samples) {
-        return ((samples - this.config.offset) / this.config.step);
+    /** @type {PyAnnoteFeatureExtractor['post_process_speaker_diarization']} */
+    post_process_speaker_diarization(...args) {
+        return /** @type {PyAnnoteFeatureExtractor} */(this.feature_extractor).post_process_speaker_diarization(...args);
     }
 
-    /**
-     * Post-processes the speaker diarization logits output by the model.
-     * @param {import('../../utils/tensor.js').Tensor} logits The speaker diarization logits output by the model.
-     * @param {number} num_samples Number of samples in the input audio.
-     * @returns {Array<Array<Object>>} The post-processed speaker diarization results.
-     */
-    post_process_speaker_diarization(logits, num_samples) {
-        const ratio = (
-            num_samples / this.samples_to_frames(num_samples)
-        ) / this.config.sampling_rate;
-
-        const results = [];
-        for (const scores of logits.tolist()) {
-            const accumulated_segments = [];
-
-            let current_speaker = -1;
-            for (let i = 0; i < scores.length; ++i) {
-                const probabilities = softmax(scores[i]);
-                const [score, id] = max(probabilities);
-                const [start, end] = [i, i + 1];
-
-                if (id !== current_speaker) {
-                    // Speaker has changed
-                    current_speaker = id;
-                    accumulated_segments.push({ id, start, end, score });
-                } else {
-                    // Continue the current segment
-                    accumulated_segments.at(-1).end = end;
-                    accumulated_segments.at(-1).score += score;
-                }
-            }
-
-            results.push(accumulated_segments.map(
-                // Convert frame-space to time-space
-                // and compute the confidence
-                ({ id, start, end, score }) => ({
-                    id,
-                    start: start * ratio,
-                    end: end * ratio,
-                    confidence: score / (end - start),
-                })
-            ));
-        }
-        return results;
+    get sampling_rate() {
+        return this.feature_extractor.config.sampling_rate;
     }
 }
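
Usage sketch (illustrative): the snippet below shows how the relocated post-processing is called through `PyAnnoteProcessor` after this change, via the delegating `post_process_speaker_diarization` method and the new `sampling_rate` getter. The checkpoint id and audio URL are assumptions for demonstration only and are not part of this patch.

import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers';

// Load the segmentation model and its processor (illustrative checkpoint id)
const model_id = 'onnx-community/pyannote-segmentation-3.0';
const model = await AutoModelForAudioFrameClassification.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);

// Read and resample the audio at the processor's sampling rate (getter added in this patch)
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.wav'; // example audio (assumed URL)
const audio = await read_audio(url, processor.sampling_rate);
const inputs = await processor(audio);

// Run the model to obtain per-frame speaker logits
const { logits } = await model(inputs);

// Post-process: the processor now delegates to PyAnnoteFeatureExtractor.post_process_speaker_diarization
const result = processor.post_process_speaker_diarization(logits, audio.length);
console.table(result[0], ['start', 'end', 'id', 'confidence']);

Moving the implementation onto the feature extractor keeps the frame-to-time conversion next to the config values it depends on (`offset`, `step`, `sampling_rate`), so the processor only forwards the call.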