Skip to content

Commit

Permalink
Implement numerically stable log_softmax() (#812)
Browse files Browse the repository at this point in the history
* Implement numerically stable log_softmax()

* Add unit tests

* Update src/utils/maths.js

---------

Co-authored-by: Joshua Lochner <[email protected]>
  • Loading branch information
taha-yassine and xenova authored Jul 1, 2024
1 parent fc34517 commit 75f557b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
17 changes: 13 additions & 4 deletions src/utils/maths.js
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,20 @@ export function softmax(arr) {
* @returns {T} The resulting log_softmax array.
*/
export function log_softmax(arr) {
// Compute the softmax values
const softmaxArr = softmax(arr);
// Compute the maximum value in the array
const maxVal = max(arr)[0];

// Compute the sum of the exponentials
let sumExps = 0;
for(let i = 0; i < arr.length; ++i) {
sumExps += Math.exp(arr[i] - maxVal);
}

// Apply log formula to each element
const logSoftmaxArr = softmaxArr.map(x => Math.log(x));
// Compute the log of the sum
const logSum = Math.log(sumExps);

// Compute the softmax values
const logSoftmaxArr = arr.map(x => x - maxVal - logSum);

return /** @type {T} */(logSoftmaxArr);
}
Expand Down
19 changes: 18 additions & 1 deletion tests/maths.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import { compare } from './test_utils.js';

import { getFile } from '../src/utils/hub.js';
import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js';
import { FFT, medianFilter, bankers_round, log_softmax } from '../src/utils/maths.js';


const fft = (arr, complex = false) => {
Expand Down Expand Up @@ -136,4 +136,21 @@ describe('Mathematical operations', () => {
});
}
});

describe('log softmax', () => {
// Should match output of scipy log_softmax
it('should compute log softmax correctly for usual values', () => {
const input = [0, 1, 2, 3];
const expected = [-3.4401896985611953, -2.4401896985611953, -1.4401896985611953, -0.44018969856119533];
const output = log_softmax(input);
compare(output, expected, 1e-13);
});

it('should compute log softmax correctly for values with large differences', () => {
const input = [1000, 1];
const expected = [0, -999];
const output = log_softmax(input);
compare(output, expected, 1e-13);
});
});
});

0 comments on commit 75f557b

Please sign in to comment.