Skip to content

Commit

Permalink
Optimize FFT (#766)
Browse files Browse the repository at this point in the history
* Optimize FFT for real transforms

* Throw error if power is not specified

huggingface/transformers#27772
  • Loading branch information
xenova authored May 22, 2024
1 parent 8963720 commit 8d166ca
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 42 deletions.
9 changes: 7 additions & 2 deletions src/utils/audio.js
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,13 @@ export function spectrogram(
throw new Error("hop_length must be greater than zero");
}

if (power === null && mel_filters !== null) {
throw new Error(
"You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram. " +
"Specify `power` to fix this issue."
);
}

if (center) {
if (pad_mode !== 'reflect') {
throw new Error(`pad_mode="${pad_mode}" not implemented yet.`)
Expand Down Expand Up @@ -547,8 +554,6 @@ export function spectrogram(
magnitudes[i] = row;
}

// TODO what should happen if power is None?
// https://github.com/huggingface/transformers/issues/27772
if (power !== null && power !== 2) {
// slight optimization to not sqrt
const pow = 2 / power; // we use 2 since we already squared
Expand Down
93 changes: 53 additions & 40 deletions src/utils/maths.js
Original file line number Diff line number Diff line change
Expand Up @@ -364,20 +364,6 @@ class P2FFT {
return res;
}

/**
* Completes the spectrum by adding its mirrored negative frequency components.
* @param {Float64Array} spectrum The input spectrum.
* @returns {void}
*/
completeSpectrum(spectrum) {
const size = this._csize;
const half = size >>> 1;
for (let i = 2; i < half; i += 2) {
spectrum[size - i] = spectrum[i];
spectrum[size - i + 1] = -spectrum[i + 1];
}
}

/**
* Performs a Fast Fourier Transform (FFT) on the given input data and stores the result in the output buffer.
*
Expand Down Expand Up @@ -466,6 +452,7 @@ class P2FFT {
}

// Loop through steps in decreasing order
const table = this.table;
for (step >>= 2; step >= 2; step >>= 2) {
len = (size / step) << 1;
const quarterLen = len >>> 2;
Expand All @@ -490,18 +477,18 @@ class P2FFT {
const Dr = out[D];
const Di = out[D + 1];

const tableBr = this.table[k];
const tableBi = inv * this.table[k + 1];
const tableBr = table[k];
const tableBi = inv * table[k + 1];
const MBr = Br * tableBr - Bi * tableBi;
const MBi = Br * tableBi + Bi * tableBr;

const tableCr = this.table[2 * k];
const tableCi = inv * this.table[2 * k + 1];
const tableCr = table[2 * k];
const tableCi = inv * table[2 * k + 1];
const MCr = Cr * tableCr - Ci * tableCi;
const MCi = Cr * tableCi + Ci * tableCr;

const tableDr = this.table[3 * k];
const tableDi = inv * this.table[3 * k + 1];
const tableDr = table[3 * k];
const tableDi = inv * table[3 * k + 1];
const MDr = Dr * tableDr - Di * tableDi;
const MDi = Dr * tableDi + Di * tableDr;

Expand Down Expand Up @@ -634,18 +621,18 @@ class P2FFT {
}
}

// TODO: Optimize once https://github.com/indutny/fft.js/issues/25 is fixed
// Loop through steps in decreasing order
const table = this.table;
for (step >>= 2; step >= 2; step >>= 2) {
len = (size / step) << 1;
const quarterLen = len >>> 2;
const halfLen = len >>> 1;
const quarterLen = halfLen >>> 1;
const hquarterLen = quarterLen >>> 1;

// Loop through offsets in the data
for (outOff = 0; outOff < size; outOff += len) {
// Full case
const limit = outOff + quarterLen - 1;
for (let i = outOff, k = 0; i < limit; i += 2, k += step) {
const A = i;
for (let i = 0, k = 0; i <= hquarterLen; i += 2, k += step) {
const A = outOff + i;
const B = A + quarterLen;
const C = B + quarterLen;
const D = C + quarterLen;
Expand All @@ -660,26 +647,30 @@ class P2FFT {
const Dr = out[D];
const Di = out[D + 1];

const tableBr = this.table[k];
const tableBi = inv * this.table[k + 1];
// Middle values
const MAr = Ar;
const MAi = Ai;

const tableBr = table[k];
const tableBi = inv * table[k + 1];
const MBr = Br * tableBr - Bi * tableBi;
const MBi = Br * tableBi + Bi * tableBr;

const tableCr = this.table[2 * k];
const tableCi = inv * this.table[2 * k + 1];
const tableCr = table[2 * k];
const tableCi = inv * table[2 * k + 1];
const MCr = Cr * tableCr - Ci * tableCi;
const MCi = Cr * tableCi + Ci * tableCr;

const tableDr = this.table[3 * k];
const tableDi = inv * this.table[3 * k + 1];
const tableDr = table[3 * k];
const tableDi = inv * table[3 * k + 1];
const MDr = Dr * tableDr - Di * tableDi;
const MDi = Dr * tableDi + Di * tableDr;

// Pre-Final values
const T0r = Ar + MCr;
const T0i = Ai + MCi;
const T1r = Ar - MCr;
const T1i = Ai - MCi;
const T0r = MAr + MCr;
const T0i = MAi + MCi;
const T1r = MAr - MCr;
const T1i = MAi - MCi;
const T2r = MBr + MDr;
const T2i = MBi + MDi;
const T3r = inv * (MBr - MDr);
Expand All @@ -690,13 +681,35 @@ class P2FFT {
out[A + 1] = T0i + T2i;
out[B] = T1r + T3i;
out[B + 1] = T1i - T3r;
out[C] = T0r - T2r;
out[C + 1] = T0i - T2i;
out[D] = T1r - T3i;
out[D + 1] = T1i + T3r;

// Output final middle point
if (i === 0) {
out[C] = T0r - T2r;
out[C + 1] = T0i - T2i;
continue;
}

// Do not overwrite ourselves
if (i === hquarterLen)
continue;

const SA = outOff + quarterLen - i;
const SB = outOff + halfLen - i;

out[SA] = T1r - inv * T3i;
out[SA + 1] = -T1i - inv * T3r;
out[SB] = T0r - inv * T2r;
out[SB + 1] = -T0i + inv * T2i;
}
}
}

// Complete the spectrum by adding its mirrored negative frequency components.
const half = size >>> 1;
for (let i = 2; i < half; i += 2) {
out[size - i] = out[i];
out[size - i + 1] = -out[i + 1];
}
}

/**
Expand Down

0 comments on commit 8d166ca

Please sign in to comment.