Skip to content

Commit

Permalink
fixed matrix issue in correlation function + error handling (#3030)
Browse files Browse the repository at this point in the history
* fixed matrix issue in correlation function + error handling

* fixed syntax error in correlation function documentation

* changed error to syntax error

* added test cases for error handling

* added test cases for error handling v2

* fixed issue in test cases for error handling v2

* fixed issue in test cases for error handling v2.1

* fixed issue in test cases for error handling v2.2

* removed math.matrix examples

* removed redundant code
  • Loading branch information
vrushaket authored Sep 20, 2023
1 parent 1edc38c commit abf9c9f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -233,5 +233,6 @@ BuildTools <[email protected]>
Anik Patel <[email protected]>
Vrushaket Chaudhari <[email protected]>
Praise Nnamonu <[email protected]>
vrushaket <[email protected]>

# Generated by tools/update-authors.js
21 changes: 16 additions & 5 deletions src/function/statistics/corr.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { factory } from '../../utils/factory.js'

const name = 'corr'
const dependencies = ['typed', 'matrix', 'mean', 'sqrt', 'sum', 'add', 'subtract', 'multiply', 'pow', 'divide']

Expand All @@ -13,8 +14,8 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed,
* Examples:
*
* math.corr([1, 2, 3, 4, 5], [4, 5, 6, 7, 8]) // returns 1
* math.corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]) // returns 0.9569941688503644
* math.corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]])) // returns DenseMatrix [0.9569941688503644, 1]
* math.corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]) //returns 0.9569941688503644
* math.corr([[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]],[[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]]) // returns [1,1]
*
* See also:
*
Expand All @@ -28,8 +29,9 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed,
'Array, Array': function (A, B) {
return _corr(A, B)
},
'Matrix, Matrix': function (xMatrix, yMatrix) {
return matrix(_corr(xMatrix.toArray(), yMatrix.toArray()))
'Matrix, Matrix': function (A, B) {
const res = _corr(A.toArray(), B.toArray())
return Array.isArray(res) ? matrix(res) : res
}
})
/**
Expand All @@ -40,13 +42,22 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed,
* @private
*/
function _corr (A, B) {
const correlations = []
if (Array.isArray(A[0]) && Array.isArray(B[0])) {
const correlations = []
if (A.length !== B.length) {
throw new SyntaxError('Dimension mismatch. Array A and B must have the same length.')
}
for (let i = 0; i < A.length; i++) {
if (A[i].length !== B[i].length) {
throw new SyntaxError('Dimension mismatch. Array A and B must have the same number of elements.')
}
correlations.push(correlation(A[i], B[i]))
}
return correlations
} else {
if (A.length !== B.length) {
throw new SyntaxError('Dimension mismatch. Array A and B must have the same number of elements.')
}
return correlation(A, B)
}
}
Expand Down
30 changes: 27 additions & 3 deletions test/unit-tests/function/statistics/corr.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,42 @@ const corr = math.corr
const BigNumber = math.BigNumber

describe('correlation', function () {
it('should return the correlation coefficient from an array', function () {
it('should return the correlation coefficient from array', function () {
assert.strictEqual(corr([new BigNumber(1), new BigNumber(2.2), new BigNumber(3), new BigNumber(4.8), new BigNumber(5)], [new BigNumber(4), new BigNumber(5.3), new BigNumber(6.6), new BigNumber(7), new BigNumber(8)]).toNumber(), 0.9569941688503653)
assert.strictEqual(corr([1, 2, 3, 4, 5], [4, 5, 6, 7, 8]), 1)
assert.strictEqual(corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]), 0.9569941688503644)
assert.deepStrictEqual(corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]]))._data, [0.9569941688503644, 1])
assert.deepStrictEqual(corr([[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]], [[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]]), [1, 1])
})

it('should throw an error if called with invalid number of arguments', function () {
it('should return the correlation coefficient from matrix', function () {
assert.strictEqual((corr(math.matrix([2, 4, 6, 8]), math.matrix([1, 2, 3, 6]))), 0.9561828874675149)
assert.deepStrictEqual(corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]])).toArray(), [0.9569941688503644, 1])
})

it('should throw an error if called with zero arguments', function () {
assert.throws(function () { corr() })
})

it('should throw an error if called with an empty array', function () {
assert.throws(function () { corr([]) })
})

it('should throw an error if called with different number of arguments', function () {
assert.throws(function () { corr(math.matrix([2, 4, 6, 8]), math.matrix([1, 2, 3])) })
})

it('should throw an error if called with number of arguments do not have same size', function () {
assert.throws(function () { corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8]])) })
})

it('should throw an error if called with different number of arguments', function () {
assert.throws(function () { corr([[1, 2, 3, 4, 5], [4, 5, 6, 7, 8], [9, 10, 11, 12]], [[1, 2, 3, 4, 5], [4, 5, 6, 7, 8]]) })
})

it('should throw an error if called with number of arguments do not have same size', function () {
assert.throws(function () { corr([[1, 2, 3, 4, 5], [4, 5, 6, 7]], [[1, 2, 3, 4, 5], []]) })
})
it('should throw an error if called with number of arguments do not have same size', function () {
assert.throws(function () { corr([1, 2, 3, 4, 5], [1, 2, 3, 4]) })
})
})

0 comments on commit abf9c9f

Please sign in to comment.