diff --git a/src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java b/src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java index b9e42124..6f20a0ef 100644 --- a/src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java +++ b/src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java @@ -23,7 +23,6 @@ package org.numenta.nupic.util; import java.io.Serializable; -import java.lang.reflect.Array; import java.util.Arrays; /** @@ -112,7 +111,7 @@ protected void checkDims(int[] index) { for(int i = 0;i < index.length - 1;i++) { if(index[i] >= dimensions[i]) { throw new IllegalArgumentException("Specified coordinates exceed the configured array dimensions " + - print1DArray(index) + " > " + print1DArray(dimensions)); + Arrays.toString(index) + " > " + Arrays.toString(dimensions)); } } } @@ -179,27 +178,6 @@ public static int[] reverse(int[] input) { return retVal; } - /** - * Prints the specified array to a returned String. - * - * @param aObject the array object to print. - * @return the array in string form suitable for display. - */ - public static String print1DArray(Object aObject) { - if (aObject.getClass().isArray()) { - if (aObject instanceof Object[]) // can we cast to Object[] - return Arrays.toString((Object[]) aObject); - else { // we can't cast to Object[] - case of primitive arrays - int length = Array.getLength(aObject); - Object[] objArr = new Object[length]; - for (int i=0; i Object[] interleave(F first, S second) { Object[] retVal = new Object[(flen = Array.getLength(first)) + (slen = Array.getLength(second))]; for(int i = 0, j = 0, k = 0;i < flen || j < slen;) { if(i < flen) { - retVal[k++] = Array.get(first, i++); + retVal[k++] = getValue(first, i++); } if(j < slen) { - retVal[k++] = Array.get(second, j++); + retVal[k++] = getValue(second, j++); } } - + return retVal; } - + /** *

* Return a new double[] containing the difference of each element and its @@ -340,60 +340,6 @@ public static int[] unravel(int[][] array) { return result; } - /** - * Takes a two-dimensional array of r rows and c columns and reshapes it to - * have (r*c)/n by n columns. The value in location [i][j] of the input array - * is copied into location [j][i] of the new array. - * - * @param array The array of values to be reshaped. - * @param n The number of columns in the created array. - * @return The new (r*c)/n by n array. - * @throws IllegalArgumentException If r*c is not evenly divisible by n. - */ - public static int[][] reshape(int[][] array, int n) throws IllegalArgumentException { - int r = array.length; - if (r == 0) { - return new int[0][0]; // Special case: zero-length array - } - if ((array.length * array[0].length) % n != 0) { - int size = array.length * array[0].length; - throw new IllegalArgumentException(size + " is not evenly divisible by " + n); - } - int c = array[0].length; - int[][] result = new int[(r * c) / n][n]; - int ii = 0; - int jj = 0; - - for (int i = 0; i < r; i++) { - for (int j = 0; j < c; j++) { - result[ii][jj] = array[i][j]; - jj++; - if (jj == n) { - jj = 0; - ii++; - } - } - } - return result; - } - - /** - * Returns an int[] with the dimensions of the input. - * @param inputArray - * @return - */ - public static int[] shape(Object inputArray) { - int nr = 1 + inputArray.getClass().getName().lastIndexOf('['); - Object oa = inputArray; - int[] l = new int[nr]; - for(int i = 0;i < nr;i++) { - int len = l[i] = Array.getLength(oa); - if (0 < len) { oa = Array.get(oa, 0); } - } - - return l; - } - /** * Sorts the array, then returns an array containing the indexes of * those sorted items in the original array. @@ -1143,28 +1089,6 @@ public static int[] sparseBinaryOr(int[] arg1, int[] arg2) { return unique(t.toArray()); } - /** - * Prints the specified array to a returned String. - * - * @param aObject the array object to print. - * @return the array in string form suitable for display. - */ - public static String print1DArray(Object aObject) { - if (aObject.getClass().isArray()) { - if (aObject instanceof Object[]) // can we cast to Object[] - { - return Arrays.toString((Object[])aObject); - } else { // we can't cast to Object[] - case of primitive arrays - int length = Array.getLength(aObject); - Object[] objArr = new Object[length]; - for (int i = 0; i < length; i++) - objArr[i] = Array.get(aObject, i); - return Arrays.toString(objArr); - } - } - return "[]"; - } - /** * Another utility to account for the difference between Python and Java. * Here the modulo operator is defined differently. @@ -2122,28 +2046,95 @@ public static int[] tail(int[] original) { * @param indexes */ public static void setValue(Object array, int value, int... indexes) { - if (indexes.length == 1) { + if(indexes.length == 1) { ((int[])array)[indexes[0]] = value; } else { - setValue(Array.get(array, indexes[0]), value, tail(indexes)); + setValue(ArrayUtils.getSlice(array, indexes[0]), value, tail(indexes)); } } - + /** * Get value for array at specified position indexes * * @param array * @param indexes */ + public static int getIntValue(Object array, int... indexes) { + Object slice = array; + if(indexes.length > 1) { + for(int i = 0;i < indexes.length - 1;i++) { + slice = ((Object[])slice)[indexes[i]]; + } + } + return ((int[])slice)[indexes[indexes.length - 1]]; + } + public static Object getValue(Object array, int... indexes) { Object slice = array; - for(int i = 0;i < indexes.length;i++) { - slice = Array.get(slice, indexes[i]); + if(indexes.length > 1) { + for(int i = 0;i < indexes.length - 1;i++) { + slice = get(slice, i); + } + } + return get(slice, indexes[indexes.length - 1]); + } + + /** + * Get slice for array at specified position + * indexes + * + * @param array + * @param indexes + */ + public static Object getSlice(Object array, int... indexes) { + Object slice = array; + if(indexes.length > 1) { + for(int i = 0;i < indexes.length - 1;i++) { + slice = ((Object[])slice)[indexes[i]]; + } } - - return slice; + return ((Object[])slice)[indexes[indexes.length - 1]]; } + /** + * Gets an element of an array. Primitive elements will be wrapped in the + * corresponding class type. + * + * @param array + * the array to access + * @param index + * the array index to access + * @return the element at array[index] + * @throws IllegalArgumentException + * if array is not an array + * @throws NullPointerException + * if array is null + * @throws ArrayIndexOutOfBoundsException + * if index is out of bounds + */ + public static Object get(Object array, int index) { + if(array instanceof Object[]) + return ((Object[])array)[index]; + if(array instanceof boolean[]) + return ((boolean[])array)[index] ? Boolean.TRUE : Boolean.FALSE; + if(array instanceof byte[]) + return new Byte(((byte[])array)[index]); + if(array instanceof char[]) + return new Character(((char[])array)[index]); + if(array instanceof short[]) + return new Short(((short[])array)[index]); + if(array instanceof int[]) + return new Integer(((int[])array)[index]); + if(array instanceof long[]) + return new Long(((long[])array)[index]); + if(array instanceof float[]) + return new Float(((float[])array)[index]); + if(array instanceof double[]) + return new Double(((double[])array)[index]); + if(array == null) + throw new NullPointerException(); + throw new IllegalArgumentException(); + } /** *Assigns the specified int value to each element of the specified any dimensional array diff --git a/src/main/java/org/numenta/nupic/util/NamedTuple.java b/src/main/java/org/numenta/nupic/util/NamedTuple.java index cc8074ca..ac550434 100644 --- a/src/main/java/org/numenta/nupic/util/NamedTuple.java +++ b/src/main/java/org/numenta/nupic/util/NamedTuple.java @@ -22,13 +22,13 @@ package org.numenta.nupic.util; -import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; import org.numenta.nupic.model.Persistable; +import static org.numenta.nupic.util.ArrayUtils.interleave; /** * Immutable tuple which adds associative lookup functionality. @@ -408,28 +408,4 @@ public boolean equals(Object obj) { } } - /** - * Returns an array containing the successive elements of each - * argument array as in [ first[0], second[0], first[1], second[1], ... ]. - * - * Arrays may be of zero length, and may be of different sizes, but may not be null. - * - * @param first the first array - * @param second the second array - * @return - */ - static Object[] interleave(F first, S second) { - int flen, slen; - Object[] retVal = new Object[(flen = Array.getLength(first)) + (slen = Array.getLength(second))]; - for(int i = 0, j = 0, k = 0;i < flen || j < slen;) { - if(i < flen) { - retVal[k++] = Array.get(first, i++); - } - if(j < slen) { - retVal[k++] = Array.get(second, j++); - } - } - - return retVal; - } } diff --git a/src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java b/src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java index 69effcd2..40c1b462 100644 --- a/src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java +++ b/src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java @@ -83,18 +83,18 @@ private void back(int val, int... coordinates) { * * @param coordinates the coordinates which specify the returned array * @return the array specified - * @throws IllegalArgumentException if the specified coordinates address - * an actual value instead of the array holding it. + * @throws IllegalArgumentException if the specified coordinates address + * an actual value instead of the array holding it. */ @Override public Object getSlice(int... coordinates) { - Object slice = ArrayUtils.getValue(this.backingArray, coordinates); - //Ensure return value is of type Array - if(!slice.getClass().isArray()) { - sliceError(coordinates); + try { + return ArrayUtils.getSlice(this.backingArray, coordinates); + } catch(ClassCastException e) { + throw new IllegalArgumentException( + "This method only returns the array holding the specified maximum index: " + + Arrays.toString(dimensions)); } - - return slice; } /** @@ -177,7 +177,7 @@ public AbstractSparseBinaryMatrix set(int[] indexes, int[] values) { */ public void clearStatistics(int row) { this.setTrueCount(row, 0); - int[] slice = (int[])Array.get(backingArray, row); + int[] slice = (int[])ArrayUtils.getSlice(backingArray, row); Arrays.fill(slice, 0); } @@ -194,7 +194,7 @@ public Integer get(int index) { return Array.getInt(this.backingArray, index); } - else return (Integer) ArrayUtils.getValue(this.backingArray, coordinates); + else return (Integer) ArrayUtils.getIntValue(this.backingArray, coordinates); } @Override diff --git a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java index 0195999c..5ef59ec9 100644 --- a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java +++ b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java @@ -34,12 +34,112 @@ import java.util.Arrays; import java.util.List; +import org.junit.Assert; import org.junit.Test; import org.numenta.nupic.model.Cell; import org.numenta.nupic.model.Column; public class ArrayUtilsTest { + @Test + public void testJavaArrayBehavior() { + Object array = Array.newInstance(int.class, 1); + Assert.assertTrue(array instanceof int[]); + array = Array.newInstance(int.class, 2); + Assert.assertTrue(array instanceof int[]); + Assert.assertFalse(array instanceof int[][]); + array = Array.newInstance(int.class, 3); + Assert.assertTrue(array instanceof int[]); + Assert.assertFalse(array instanceof int[][]); + Assert.assertFalse(array instanceof int[][][]); + } + + @Test + public void testGet() { + int[] singleDimArray = new int[2]; + singleDimArray[0]=0; + singleDimArray[1]=1; + assertEquals(0, ArrayUtils.getIntValue(singleDimArray, 0)); + assertEquals(1, ArrayUtils.getIntValue(singleDimArray, 1)); + + int[][] twoDimArray = new int[2][2]; + twoDimArray[0]=new int[2]; + twoDimArray[1]=new int[2]; + twoDimArray[0][0]=0; + twoDimArray[0][1]=1; + twoDimArray[1][0]=2; + twoDimArray[1][1]=3; + assertEquals(0, ArrayUtils.getIntValue(twoDimArray, 0, 0)); + assertEquals(1, ArrayUtils.getIntValue(twoDimArray, 0, 1)); + assertEquals(2, ArrayUtils.getIntValue(twoDimArray, 1, 0)); + assertEquals(3, ArrayUtils.getIntValue(twoDimArray, 1, 1)); + } + + @Test + public void testSetAndGet() { + int[] dimensions = new int[]{2, 2, 2}; + int[][][] threeDimArray = (int[][][]) Array.newInstance(int.class, dimensions); + System.out.println("input array:"); + for(int i = 0; i < 8; i++) { + int[] coords = computeCoordinates(i, dimensions); + System.out.println(Arrays.toString(coords)); + ArrayUtils.setValue(threeDimArray, i, coords); + } + for(int i = 0; i < 8; i++) { + int[] coords = computeCoordinates(i, dimensions); + assertEquals(i, ArrayUtils.getIntValue(threeDimArray, coords)); + } + + int[][] twoDimArray = (int[][]) ArrayUtils.getSlice(threeDimArray, new int[]{0}); + for(int i = 0; i < 4; i++) { + int[] coords = computeCoordinates(i, new int[]{2,2}); + assertEquals(i, ArrayUtils.getIntValue(twoDimArray, coords)); + } + int[] oneDimArray = (int[]) ArrayUtils.getSlice(threeDimArray, new int[]{0,0}); + assertEquals(0, oneDimArray[0]); + assertEquals(1, oneDimArray[1]); + oneDimArray = (int[]) ArrayUtils.getSlice(threeDimArray, new int[]{0,1}); + assertEquals(2, oneDimArray[0]); + assertEquals(3, oneDimArray[1]); + oneDimArray = (int[]) ArrayUtils.getSlice(threeDimArray, new int[]{1,0}); + assertEquals(4, oneDimArray[0]); + assertEquals(5, oneDimArray[1]); + oneDimArray = (int[]) ArrayUtils.getSlice(threeDimArray, new int[]{1,1}); + assertEquals(6, oneDimArray[0]); + assertEquals(7, oneDimArray[1]); + } + + private static int[] computeCoordinates(int index, int[] dimensions) { + int[] returnVal = new int[dimensions.length]; + int[] dimensionMultiples = initDimensionMultiples(dimensions); + int base = index; + for(int i = 0;i < dimensionMultiples.length; i++) { + int quotient = base / dimensionMultiples[i]; + base %= dimensionMultiples[i]; + returnVal[i] = quotient; + } + return returnVal; + } + + private static int[] initDimensionMultiples(int[] dimensions) { + int holder = 1; + int len = dimensions.length; + int[] dimensionMultiples = new int[dimensions.length]; + for(int i = 0;i < len;i++) { + holder *= (i == 0 ? 1 : dimensions[len - i]); + dimensionMultiples[len - 1 - i] = holder; + } + return dimensionMultiples; + } + + public static int[] reverse(int[] input) { + int[] retVal = new int[input.length]; + for(int i = input.length - 1, j = 0;i >= 0;i--, j++) { + retVal[j] = input[i]; + } + return retVal; + } + @Test public void testToBytes() { boolean[] ba = { true, true, }; @@ -115,49 +215,7 @@ public void testArgsort() { args = ArrayUtils.argsort(new int[] { 11, 2, 3, 7, 0 }, 0, 3); assertTrue(Arrays.equals(new int[] {4, 1, 2}, args)); } - - @Test - public void testShape() { - int[][] inputPattern = { { 2, 3, 4, 5 }, { 6, 7, 8, 9} }; - int[] shape = ArrayUtils.shape(inputPattern); - assertTrue(Arrays.equals(new int[] { 2, 4 }, shape)); - } - - @Test - public void testReshape() { - int[][] test = { - { 0, 1, 2, 3, 4, 5 }, - { 6, 7, 8, 9, 10, 11 } - }; - - int[][] expected = { - { 0, 1, 2 }, - { 3, 4, 5 }, - { 6, 7, 8 }, - { 9, 10, 11 } - }; - - int[][] result = ArrayUtils.reshape(test, 3); - for(int i = 0;i < result.length;i++) { - for(int j = 0;j < result[i].length;j++) { - assertEquals(expected[i][j], result[i][j]); - } - } - - // Unhappy case - try { - ArrayUtils.reshape(test, 5); - }catch(Exception e) { - assertTrue(e instanceof IllegalArgumentException); - assertEquals("12 is not evenly divisible by 5", e.getMessage()); - } - - // Test zero-length case - int[] result4 = ArrayUtils.unravel(new int[0][]); - assertNotNull(result4); - assertTrue(result4.length == 0); - } - + @Test public void testRavelAndUnRavel() { int[] test = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 }; diff --git a/src/test/java/org/numenta/nupic/util/MatrixTest.java b/src/test/java/org/numenta/nupic/util/MatrixTest.java index b6982c49..5e2a0d67 100644 --- a/src/test/java/org/numenta/nupic/util/MatrixTest.java +++ b/src/test/java/org/numenta/nupic/util/MatrixTest.java @@ -49,7 +49,7 @@ public void testBitSetMatrixSet() { } assertArrayEquals(expected, asDense(bsm)); - assertEquals(Arrays.toString(expected), FlatArrayMatrix.print1DArray(asDense(bsm))); + assertEquals(Arrays.toString(expected), Arrays.toString(asDense(bsm))); } @Test @@ -63,7 +63,7 @@ public void testFlatArrayMatrixSet() { } assertArrayEquals(expected, asDense(fam)); - assertEquals(Arrays.toString(expected), FlatArrayMatrix.print1DArray(asDense(fam))); + assertEquals(Arrays.toString(expected), Arrays.toString(asDense(fam))); } private Object[] asDense(FlatMatrix matrix) { diff --git a/src/test/java/org/numenta/nupic/util/NamedTupleTest.java b/src/test/java/org/numenta/nupic/util/NamedTupleTest.java index 77089ca8..8753daa8 100644 --- a/src/test/java/org/numenta/nupic/util/NamedTupleTest.java +++ b/src/test/java/org/numenta/nupic/util/NamedTupleTest.java @@ -133,87 +133,6 @@ public void testEquality() { assertNotEquals(nt, nt2); } - @Test - public void testInterleave() { - String[] f = { "0" }; - double[] s = { 0.8 }; - - // Test most simple interleave of equal length arrays - Object[] result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals(0.8, result[1]); - - // Test simple interleave of larger array - f = new String[] { "0", "1" }; - s = new double[] { 0.42, 2.5 }; - result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals(0.42, result[1]); - assertEquals("1", result[2]); - assertEquals(2.5, result[3]); - - // Test complex interleave of larger array - f = new String[] { "0", "1", "bob", "harry", "digit", "temperature" }; - s = new double[] { 0.42, 2.5, .001, 1e-2, 34.0, .123 }; - result = NamedTuple.interleave(f, s); - for(int i = 0, j = 0;j < result.length;i++, j+=2) { - assertEquals(f[i], result[j]); - assertEquals(s[i], result[j + 1]); - } - - // Test interleave with zero length of first - f = new String[0]; - s = new double[] { 0.42, 2.5 }; - result = NamedTuple.interleave(f, s); - assertEquals(0.42, result[0]); - assertEquals(2.5, result[1]); - - // Test interleave with zero length of second - f = new String[] { "0", "1" }; - s = new double[0]; - result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals("1", result[1]); - - // Test complex unequal length: left side smaller - f = new String[] { "0", "1", "bob" }; - s = new double[] { 0.42, 2.5, .001, 1e-2, 34.0, .123 }; - result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals(0.42, result[1]); - assertEquals("1", result[2]); - assertEquals(2.5, result[3]); - assertEquals("bob", result[4]); - assertEquals(.001, result[5]); - assertEquals(1e-2, result[6]); - assertEquals(34.0, result[7]); - assertEquals(.123, result[8]); - - // Test complex unequal length: right side smaller - f = new String[] { "0", "1", "bob", "harry", "digit", "temperature" }; - s = new double[] { 0.42, 2.5, .001 }; - result = NamedTuple.interleave(f, s); - assertEquals("0", result[0]); - assertEquals(0.42, result[1]); - assertEquals("1", result[2]); - assertEquals(2.5, result[3]); - assertEquals("bob", result[4]); - assertEquals(.001, result[5]); - assertEquals("harry", result[6]); - assertEquals("digit", result[7]); - assertEquals("temperature", result[8]); - - // Negative testing - try { - f = null; - s = new double[] { 0.42, 2.5, .001 }; - result = NamedTuple.interleave(f, s); - fail(); - }catch(Exception e) { - assertEquals(NullPointerException.class, e.getClass()); - } - } - @Test public void testGetValues() { Set set = new LinkedHashSet<>(); diff --git a/src/test/java/org/numenta/nupic/util/SparseObjectMatrixTest.java b/src/test/java/org/numenta/nupic/util/SparseObjectMatrixTest.java index 882144f7..88867f5b 100644 --- a/src/test/java/org/numenta/nupic/util/SparseObjectMatrixTest.java +++ b/src/test/java/org/numenta/nupic/util/SparseObjectMatrixTest.java @@ -25,6 +25,8 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import java.util.Arrays; + import org.junit.Test; public class SparseObjectMatrixTest { @@ -38,7 +40,7 @@ public void testGetDimensionMultiples() { sm = new SparseObjectMatrix(new int[] { 1, 2, 3, 4, 5 }); dm = sm.getDimensionMultiples(); - assertEquals(ArrayUtils.print1DArray(dm), "[120, 60, 20, 5, 1]"); + assertEquals(Arrays.toString(dm), "[120, 60, 20, 5, 1]"); } /**