From 1c220c7c7a11141c149516849da451961125291f Mon Sep 17 00:00:00 2001 From: jzajic Date: Mon, 5 Feb 2018 13:14:32 +0100 Subject: [PATCH 1/2] array optimization - part 1 - remove duplicate and unused code --- .../nupic/util/AbstractFlatMatrix.java | 24 +----- .../org/numenta/nupic/util/ArrayUtils.java | 76 ----------------- .../org/numenta/nupic/util/NamedTuple.java | 26 +----- .../numenta/nupic/util/ArrayUtilsTest.java | 44 +--------- .../org/numenta/nupic/util/MatrixTest.java | 4 +- .../numenta/nupic/util/NamedTupleTest.java | 81 ------------------- .../nupic/util/SparseObjectMatrixTest.java | 4 +- 7 files changed, 8 insertions(+), 251 deletions(-) diff --git a/src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java b/src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java index b9e42124..6f20a0ef 100644 --- a/src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java +++ b/src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java @@ -23,7 +23,6 @@ package org.numenta.nupic.util; import java.io.Serializable; -import java.lang.reflect.Array; import java.util.Arrays; /** @@ -112,7 +111,7 @@ protected void checkDims(int[] index) { for(int i = 0;i < index.length - 1;i++) { if(index[i] >= dimensions[i]) { throw new IllegalArgumentException("Specified coordinates exceed the configured array dimensions " + - print1DArray(index) + " > " + print1DArray(dimensions)); + Arrays.toString(index) + " > " + Arrays.toString(dimensions)); } } } @@ -179,27 +178,6 @@ public static int[] reverse(int[] input) { return retVal; } - /** - * Prints the specified array to a returned String. - * - * @param aObject the array object to print. - * @return the array in string form suitable for display. - */ - public static String print1DArray(Object aObject) { - if (aObject.getClass().isArray()) { - if (aObject instanceof Object[]) // can we cast to Object[] - return Arrays.toString((Object[]) aObject); - else { // we can't cast to Object[] - case of primitive arrays - int length = Array.getLength(aObject); - Object[] objArr = new Object[length]; - for (int i=0; i Object[] interleave(F first, S second) { - int flen, slen; - Object[] retVal = new Object[(flen = Array.getLength(first)) + (slen = Array.getLength(second))]; - for(int i = 0, j = 0, k = 0;i < flen || j < slen;) { - if(i < flen) { - retVal[k++] = Array.get(first, i++); - } - if(j < slen) { - retVal[k++] = Array.get(second, j++); - } - } - - return retVal; - } } diff --git a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java index 0195999c..4531a578 100644 --- a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java +++ b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java @@ -115,49 +115,7 @@ public void testArgsort() { args = ArrayUtils.argsort(new int[] { 11, 2, 3, 7, 0 }, 0, 3); assertTrue(Arrays.equals(new int[] {4, 1, 2}, args)); } - - @Test - public void testShape() { - int[][] inputPattern = { { 2, 3, 4, 5 }, { 6, 7, 8, 9} }; - int[] shape = ArrayUtils.shape(inputPattern); - assertTrue(Arrays.equals(new int[] { 2, 4 }, shape)); - } - - @Test - public void testReshape() { - int[][] test = { - { 0, 1, 2, 3, 4, 5 }, - { 6, 7, 8, 9, 10, 11 } - }; - - int[][] expected = { - { 0, 1, 2 }, - { 3, 4, 5 }, - { 6, 7, 8 }, - { 9, 10, 11 } - }; - - int[][] result = ArrayUtils.reshape(test, 3); - for(int i = 0;i < result.length;i++) { - for(int j = 0;j < result[i].length;j++) { - assertEquals(expected[i][j], result[i][j]); - } - } - - // Unhappy case - try { - ArrayUtils.reshape(test, 5); - }catch(Exception e) { - assertTrue(e instanceof IllegalArgumentException); - assertEquals("12 is not evenly divisible by 5", e.getMessage()); - } - - // Test zero-length case - int[] result4 = ArrayUtils.unravel(new int[0][]); - assertNotNull(result4); - assertTrue(result4.length == 0); - } - + @Test public void testRavelAndUnRavel() { int[] test = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 }; diff --git a/src/test/java/org/numenta/nupic/util/MatrixTest.java b/src/test/java/org/numenta/nupic/util/MatrixTest.java index b6982c49..5e2a0d67 100644 --- a/src/test/java/org/numenta/nupic/util/MatrixTest.java +++ b/src/test/java/org/numenta/nupic/util/MatrixTest.java @@ -49,7 +49,7 @@ public void testBitSetMatrixSet() { } assertArrayEquals(expected, asDense(bsm)); - assertEquals(Arrays.toString(expected), FlatArrayMatrix.print1DArray(asDense(bsm))); + assertEquals(Arrays.toString(expected), Arrays.toString(asDense(bsm))); } @Test @@ -63,7 +63,7 @@ public void testFlatArrayMatrixSet() { } assertArrayEquals(expected, asDense(fam)); - assertEquals(Arrays.toString(expected), FlatArrayMatrix.print1DArray(asDense(fam))); + assertEquals(Arrays.toString(expected), Arrays.toString(asDense(fam))); } private Object[] asDense(FlatMatrix matrix) { diff --git a/src/test/java/org/numenta/nupic/util/NamedTupleTest.java b/src/test/java/org/numenta/nupic/util/NamedTupleTest.java index 77089ca8..8753daa8 100644 --- a/src/test/java/org/numenta/nupic/util/NamedTupleTest.java +++ b/src/test/java/org/numenta/nupic/util/NamedTupleTest.java @@ -133,87 +133,6 @@ public void testEquality() { assertNotEquals(nt, nt2); } - @Test - public void testInterleave() { - String[] f = { "0" }; - double[] s = { 0.8 }; - - // Test most simple interleave of equal length arrays - Object[] result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals(0.8, result[1]); - - // Test simple interleave of larger array - f = new String[] { "0", "1" }; - s = new double[] { 0.42, 2.5 }; - result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals(0.42, result[1]); - assertEquals("1", result[2]); - assertEquals(2.5, result[3]); - - // Test complex interleave of larger array - f = new String[] { "0", "1", "bob", "harry", "digit", "temperature" }; - s = new double[] { 0.42, 2.5, .001, 1e-2, 34.0, .123 }; - result = NamedTuple.interleave(f, s); - for(int i = 0, j = 0;j < result.length;i++, j+=2) { - assertEquals(f[i], result[j]); - assertEquals(s[i], result[j + 1]); - } - - // Test interleave with zero length of first - f = new String[0]; - s = new double[] { 0.42, 2.5 }; - result = NamedTuple.interleave(f, s); - assertEquals(0.42, result[0]); - assertEquals(2.5, result[1]); - - // Test interleave with zero length of second - f = new String[] { "0", "1" }; - s = new double[0]; - result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals("1", result[1]); - - // Test complex unequal length: left side smaller - f = new String[] { "0", "1", "bob" }; - s = new double[] { 0.42, 2.5, .001, 1e-2, 34.0, .123 }; - result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals(0.42, result[1]); - assertEquals("1", result[2]); - assertEquals(2.5, result[3]); - assertEquals("bob", result[4]); - assertEquals(.001, result[5]); - assertEquals(1e-2, result[6]); - assertEquals(34.0, result[7]); - assertEquals(.123, result[8]); - - // Test complex unequal length: right side smaller - f = new String[] { "0", "1", "bob", "harry", "digit", "temperature" }; - s = new double[] { 0.42, 2.5, .001 }; - result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals(0.42, result[1]); - assertEquals("1", result[2]); - assertEquals(2.5, result[3]); - assertEquals("bob", result[4]); - assertEquals(.001, result[5]); - assertEquals("harry", result[6]); - assertEquals("digit", result[7]); - assertEquals("temperature", result[8]); - - // Negative testing - try { - f = null; - s = new double[] { 0.42, 2.5, .001 }; - result = NamedTuple.interleave(f, s); - fail(); - }catch(Exception e) { - assertEquals(NullPointerException.class, e.getClass()); - } - } - @Test public void testGetValues() { Set set = new LinkedHashSet<>(); diff --git a/src/test/java/org/numenta/nupic/util/SparseObjectMatrixTest.java b/src/test/java/org/numenta/nupic/util/SparseObjectMatrixTest.java index 882144f7..88867f5b 100644 --- a/src/test/java/org/numenta/nupic/util/SparseObjectMatrixTest.java +++ b/src/test/java/org/numenta/nupic/util/SparseObjectMatrixTest.java @@ -25,6 +25,8 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import java.util.Arrays; + import org.junit.Test; public class SparseObjectMatrixTest { @@ -38,7 +40,7 @@ public void testGetDimensionMultiples() { sm = new SparseObjectMatrix(new int[] { 1, 2, 3, 4, 5 }); dm = sm.getDimensionMultiples(); - assertEquals(ArrayUtils.print1DArray(dm), "[120, 60, 20, 5, 1]"); + assertEquals(Arrays.toString(dm), "[120, 60, 20, 5, 1]"); } /** From 6c122488f0bfcf1542b31bce63e8c2a7d5d35ea2 Mon Sep 17 00:00:00 2001 From: jzajic Date: Mon, 5 Feb 2018 22:39:20 +0100 Subject: [PATCH 2/2] array optimization - part 2 - replace Array.get --- .../org/numenta/nupic/util/ArrayUtils.java | 91 +++++++++++++--- .../nupic/util/SparseBinaryMatrix.java | 20 ++-- .../numenta/nupic/util/ArrayUtilsTest.java | 100 ++++++++++++++++++ 3 files changed, 189 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/numenta/nupic/util/ArrayUtils.java b/src/main/java/org/numenta/nupic/util/ArrayUtils.java index 1962051d..423db2bd 100644 --- a/src/main/java/org/numenta/nupic/util/ArrayUtils.java +++ b/src/main/java/org/numenta/nupic/util/ArrayUtils.java @@ -84,7 +84,7 @@ public static int product(int[] dims) { return retVal; } - + /** * Returns an array containing the successive elements of each * argument array as in [ first[0], second[0], first[1], second[1], ... ]. @@ -100,16 +100,16 @@ public static Object[] interleave(F first, S second) { Object[] retVal = new Object[(flen = Array.getLength(first)) + (slen = Array.getLength(second))]; for(int i = 0, j = 0, k = 0;i < flen || j < slen;) { if(i < flen) { - retVal[k++] = Array.get(first, i++); + retVal[k++] = getValue(first, i++); } if(j < slen) { - retVal[k++] = Array.get(second, j++); + retVal[k++] = getValue(second, j++); } } - + return retVal; } - + /** *

* Return a new double[] containing the difference of each element and its @@ -2046,28 +2046,95 @@ public static int[] tail(int[] original) { * @param indexes */ public static void setValue(Object array, int value, int... indexes) { - if (indexes.length == 1) { + if(indexes.length == 1) { ((int[])array)[indexes[0]] = value; } else { - setValue(Array.get(array, indexes[0]), value, tail(indexes)); + setValue(ArrayUtils.getSlice(array, indexes[0]), value, tail(indexes)); } } - + /** * Get value for array at specified position indexes * * @param array * @param indexes */ + public static int getIntValue(Object array, int... indexes) { + Object slice = array; + if(indexes.length > 1) { + for(int i = 0;i < indexes.length - 1;i++) { + slice = ((Object[])slice)[indexes[i]]; + } + } + return ((int[])slice)[indexes[indexes.length - 1]]; + } + public static Object getValue(Object array, int... indexes) { Object slice = array; - for(int i = 0;i < indexes.length;i++) { - slice = Array.get(slice, indexes[i]); + if(indexes.length > 1) { + for(int i = 0;i < indexes.length - 1;i++) { + slice = get(slice, i); + } } - - return slice; + return get(slice, indexes[indexes.length - 1]); } + /** + * Get slice for array at specified position + * indexes + * + * @param array + * @param indexes + */ + public static Object getSlice(Object array, int... indexes) { + Object slice = array; + if(indexes.length > 1) { + for(int i = 0;i < indexes.length - 1;i++) { + slice = ((Object[])slice)[indexes[i]]; + } + } + return ((Object[])slice)[indexes[indexes.length - 1]]; + } + + /** + * Gets an element of an array. Primitive elements will be wrapped in the + * corresponding class type. + * + * @param array + * the array to access + * @param index + * the array index to access + * @return the element at array[index] + * @throws IllegalArgumentException + * if array is not an array + * @throws NullPointerException + * if array is null + * @throws ArrayIndexOutOfBoundsException + * if index is out of bounds + */ + public static Object get(Object array, int index) { + if(array instanceof Object[]) + return ((Object[])array)[index]; + if(array instanceof boolean[]) + return ((boolean[])array)[index] ? Boolean.TRUE : Boolean.FALSE; + if(array instanceof byte[]) + return new Byte(((byte[])array)[index]); + if(array instanceof char[]) + return new Character(((char[])array)[index]); + if(array instanceof short[]) + return new Short(((short[])array)[index]); + if(array instanceof int[]) + return new Integer(((int[])array)[index]); + if(array instanceof long[]) + return new Long(((long[])array)[index]); + if(array instanceof float[]) + return new Float(((float[])array)[index]); + if(array instanceof double[]) + return new Double(((double[])array)[index]); + if(array == null) + throw new NullPointerException(); + throw new IllegalArgumentException(); + } /** *Assigns the specified int value to each element of the specified any dimensional array diff --git a/src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java b/src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java index 69effcd2..40c1b462 100644 --- a/src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java +++ b/src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java @@ -83,18 +83,18 @@ private void back(int val, int... coordinates) { * * @param coordinates the coordinates which specify the returned array * @return the array specified - * @throws IllegalArgumentException if the specified coordinates address - * an actual value instead of the array holding it. + * @throws IllegalArgumentException if the specified coordinates address + * an actual value instead of the array holding it. */ @Override public Object getSlice(int... coordinates) { - Object slice = ArrayUtils.getValue(this.backingArray, coordinates); - //Ensure return value is of type Array - if(!slice.getClass().isArray()) { - sliceError(coordinates); + try { + return ArrayUtils.getSlice(this.backingArray, coordinates); + } catch(ClassCastException e) { + throw new IllegalArgumentException( + "This method only returns the array holding the specified maximum index: " + + Arrays.toString(dimensions)); } - - return slice; } /** @@ -177,7 +177,7 @@ public AbstractSparseBinaryMatrix set(int[] indexes, int[] values) { */ public void clearStatistics(int row) { this.setTrueCount(row, 0); - int[] slice = (int[])Array.get(backingArray, row); + int[] slice = (int[])ArrayUtils.getSlice(backingArray, row); Arrays.fill(slice, 0); } @@ -194,7 +194,7 @@ public Integer get(int index) { return Array.getInt(this.backingArray, index); } - else return (Integer) ArrayUtils.getValue(this.backingArray, coordinates); + else return (Integer) ArrayUtils.getIntValue(this.backingArray, coordinates); } @Override diff --git a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java index 4531a578..5ef59ec9 100644 --- a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java +++ b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java @@ -34,12 +34,112 @@ import java.util.Arrays; import java.util.List; +import org.junit.Assert; import org.junit.Test; import org.numenta.nupic.model.Cell; import org.numenta.nupic.model.Column; public class ArrayUtilsTest { + @Test + public void testJavaArrayBehavior() { + Object array = Array.newInstance(int.class, 1); + Assert.assertTrue(array instanceof int[]); + array = Array.newInstance(int.class, 2); + Assert.assertTrue(array instanceof int[]); + Assert.assertFalse(array instanceof int[][]); + array = Array.newInstance(int.class, 3); + Assert.assertTrue(array instanceof int[]); + Assert.assertFalse(array instanceof int[][]); + Assert.assertFalse(array instanceof int[][][]); + } + + @Test + public void testGet() { + int[] singleDimArray = new int[2]; + singleDimArray[0]=0; + singleDimArray[1]=1; + assertEquals(0, ArrayUtils.getIntValue(singleDimArray, 0)); + assertEquals(1, ArrayUtils.getIntValue(singleDimArray, 1)); + + int[][] twoDimArray = new int[2][2]; + twoDimArray[0]=new int[2]; + twoDimArray[1]=new int[2]; + twoDimArray[0][0]=0; + twoDimArray[0][1]=1; + twoDimArray[1][0]=2; + twoDimArray[1][1]=3; + assertEquals(0, ArrayUtils.getIntValue(twoDimArray, 0, 0)); + assertEquals(1, ArrayUtils.getIntValue(twoDimArray, 0, 1)); + assertEquals(2, ArrayUtils.getIntValue(twoDimArray, 1, 0)); + assertEquals(3, ArrayUtils.getIntValue(twoDimArray, 1, 1)); + } + + @Test + public void testSetAndGet() { + int[] dimensions = new int[]{2, 2, 2}; + int[][][] threeDimArray = (int[][][]) Array.newInstance(int.class, dimensions); + System.out.println("input array:"); + for(int i = 0; i < 8; i++) { + int[] coords = computeCoordinates(i, dimensions); + System.out.println(Arrays.toString(coords)); + ArrayUtils.setValue(threeDimArray, i, coords); + } + for(int i = 0; i < 8; i++) { + int[] coords = computeCoordinates(i, dimensions); + assertEquals(i, ArrayUtils.getIntValue(threeDimArray, coords)); + } + + int[][] twoDimArray = (int[][]) ArrayUtils.getSlice(threeDimArray, new int[]{0}); + for(int i = 0; i < 4; i++) { + int[] coords = computeCoordinates(i, new int[]{2,2}); + assertEquals(i, ArrayUtils.getIntValue(twoDimArray, coords)); + } + int[] oneDimArray = (int[]) ArrayUtils.getSlice(threeDimArray, new int[]{0,0}); + assertEquals(0, oneDimArray[0]); + assertEquals(1, oneDimArray[1]); + oneDimArray = (int[]) ArrayUtils.getSlice(threeDimArray, new int[]{0,1}); + assertEquals(2, oneDimArray[0]); + assertEquals(3, oneDimArray[1]); + oneDimArray = (int[]) ArrayUtils.getSlice(threeDimArray, new int[]{1,0}); + assertEquals(4, oneDimArray[0]); + assertEquals(5, oneDimArray[1]); + oneDimArray = (int[]) ArrayUtils.getSlice(threeDimArray, new int[]{1,1}); + assertEquals(6, oneDimArray[0]); + assertEquals(7, oneDimArray[1]); + } + + private static int[] computeCoordinates(int index, int[] dimensions) { + int[] returnVal = new int[dimensions.length]; + int[] dimensionMultiples = initDimensionMultiples(dimensions); + int base = index; + for(int i = 0;i < dimensionMultiples.length; i++) { + int quotient = base / dimensionMultiples[i]; + base %= dimensionMultiples[i]; + returnVal[i] = quotient; + } + return returnVal; + } + + private static int[] initDimensionMultiples(int[] dimensions) { + int holder = 1; + int len = dimensions.length; + int[] dimensionMultiples = new int[dimensions.length]; + for(int i = 0;i < len;i++) { + holder *= (i == 0 ? 1 : dimensions[len - i]); + dimensionMultiples[len - 1 - i] = holder; + } + return dimensionMultiples; + } + + public static int[] reverse(int[] input) { + int[] retVal = new int[input.length]; + for(int i = input.length - 1, j = 0;i >= 0;i--, j++) { + retVal[j] = input[i]; + } + return retVal; + } + @Test public void testToBytes() { boolean[] ba = { true, true, };