Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance optimization of multidimensional array access #535

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 1 addition & 23 deletions src/main/java/org/numenta/nupic/util/AbstractFlatMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
package org.numenta.nupic.util;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;

/**
Expand Down Expand Up @@ -112,7 +111,7 @@ protected void checkDims(int[] index) {
for(int i = 0;i < index.length - 1;i++) {
if(index[i] >= dimensions[i]) {
throw new IllegalArgumentException("Specified coordinates exceed the configured array dimensions " +
print1DArray(index) + " > " + print1DArray(dimensions));
Arrays.toString(index) + " > " + Arrays.toString(dimensions));
}
}
}
Expand Down Expand Up @@ -179,27 +178,6 @@ public static int[] reverse(int[] input) {
return retVal;
}

/**
* Prints the specified array to a returned String.
*
* @param aObject the array object to print.
* @return the array in string form suitable for display.
*/
public static String print1DArray(Object aObject) {
if (aObject.getClass().isArray()) {
if (aObject instanceof Object[]) // can we cast to Object[]
return Arrays.toString((Object[]) aObject);
else { // we can't cast to Object[] - case of primitive arrays
int length = Array.getLength(aObject);
Object[] objArr = new Object[length];
for (int i=0; i<length; i++)
objArr[i] = Array.get(aObject, i);
return Arrays.toString(objArr);
}
}
return "[]";
}

@Override
public abstract T get(int index);

Expand Down
167 changes: 79 additions & 88 deletions src/main/java/org/numenta/nupic/util/ArrayUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public static int product(int[] dims) {

return retVal;
}

/**
* Returns an array containing the successive elements of each
* argument array as in [ first[0], second[0], first[1], second[1], ... ].
Expand All @@ -100,16 +100,16 @@ public static <F, S> Object[] interleave(F first, S second) {
Object[] retVal = new Object[(flen = Array.getLength(first)) + (slen = Array.getLength(second))];
for(int i = 0, j = 0, k = 0;i < flen || j < slen;) {
if(i < flen) {
retVal[k++] = Array.get(first, i++);
retVal[k++] = getValue(first, i++);
}
if(j < slen) {
retVal[k++] = Array.get(second, j++);
retVal[k++] = getValue(second, j++);
}
}

return retVal;
}

/**
* <p>
* Return a new double[] containing the difference of each element and its
Expand Down Expand Up @@ -340,60 +340,6 @@ public static int[] unravel(int[][] array) {
return result;
}

/**
* Takes a two-dimensional array of r rows and c columns and reshapes it to
* have (r*c)/n by n columns. The value in location [i][j] of the input array
* is copied into location [j][i] of the new array.
*
* @param array The array of values to be reshaped.
* @param n The number of columns in the created array.
* @return The new (r*c)/n by n array.
* @throws IllegalArgumentException If r*c is not evenly divisible by n.
*/
public static int[][] reshape(int[][] array, int n) throws IllegalArgumentException {
int r = array.length;
if (r == 0) {
return new int[0][0]; // Special case: zero-length array
}
if ((array.length * array[0].length) % n != 0) {
int size = array.length * array[0].length;
throw new IllegalArgumentException(size + " is not evenly divisible by " + n);
}
int c = array[0].length;
int[][] result = new int[(r * c) / n][n];
int ii = 0;
int jj = 0;

for (int i = 0; i < r; i++) {
for (int j = 0; j < c; j++) {
result[ii][jj] = array[i][j];
jj++;
if (jj == n) {
jj = 0;
ii++;
}
}
}
return result;
}

/**
* Returns an int[] with the dimensions of the input.
* @param inputArray
* @return
*/
public static int[] shape(Object inputArray) {
int nr = 1 + inputArray.getClass().getName().lastIndexOf('[');
Object oa = inputArray;
int[] l = new int[nr];
for(int i = 0;i < nr;i++) {
int len = l[i] = Array.getLength(oa);
if (0 < len) { oa = Array.get(oa, 0); }
}

return l;
}

/**
* Sorts the array, then returns an array containing the indexes of
* those sorted items in the original array.
Expand Down Expand Up @@ -1143,28 +1089,6 @@ public static int[] sparseBinaryOr(int[] arg1, int[] arg2) {
return unique(t.toArray());
}

/**
* Prints the specified array to a returned String.
*
* @param aObject the array object to print.
* @return the array in string form suitable for display.
*/
public static String print1DArray(Object aObject) {
if (aObject.getClass().isArray()) {
if (aObject instanceof Object[]) // can we cast to Object[]
{
return Arrays.toString((Object[])aObject);
} else { // we can't cast to Object[] - case of primitive arrays
int length = Array.getLength(aObject);
Object[] objArr = new Object[length];
for (int i = 0; i < length; i++)
objArr[i] = Array.get(aObject, i);
return Arrays.toString(objArr);
}
}
return "[]";
}

/**
* Another utility to account for the difference between Python and Java.
* Here the modulo operator is defined differently.
Expand Down Expand Up @@ -2122,28 +2046,95 @@ public static int[] tail(int[] original) {
* @param indexes
*/
public static void setValue(Object array, int value, int... indexes) {
if (indexes.length == 1) {
if(indexes.length == 1) {
((int[])array)[indexes[0]] = value;
} else {
setValue(Array.get(array, indexes[0]), value, tail(indexes));
setValue(ArrayUtils.getSlice(array, indexes[0]), value, tail(indexes));
}
}

/**
* Get <tt>value</tt> for <tt>array</tt> at specified position <tt>indexes</tt>
*
* @param array
* @param indexes
*/
public static int getIntValue(Object array, int... indexes) {
Object slice = array;
if(indexes.length > 1) {
for(int i = 0;i < indexes.length - 1;i++) {
slice = ((Object[])slice)[indexes[i]];
}
}
return ((int[])slice)[indexes[indexes.length - 1]];
}

public static Object getValue(Object array, int... indexes) {
Object slice = array;
for(int i = 0;i < indexes.length;i++) {
slice = Array.get(slice, indexes[i]);
if(indexes.length > 1) {
for(int i = 0;i < indexes.length - 1;i++) {
slice = get(slice, i);
}
}
return get(slice, indexes[indexes.length - 1]);
}

/**
* Get <tt>slice</tt> for <tt>array</tt> at specified position
* <tt>indexes</tt>
*
* @param array
* @param indexes
*/
public static Object getSlice(Object array, int... indexes) {
Object slice = array;
if(indexes.length > 1) {
for(int i = 0;i < indexes.length - 1;i++) {
slice = ((Object[])slice)[indexes[i]];
}
}

return slice;
return ((Object[])slice)[indexes[indexes.length - 1]];
}

/**
* Gets an element of an array. Primitive elements will be wrapped in the
* corresponding class type.
*
* @param array
* the array to access
* @param index
* the array index to access
* @return the element at <code>array[index]</code>
* @throws IllegalArgumentException
* if <code>array</code> is not an array
* @throws NullPointerException
* if <code>array</code> is null
* @throws ArrayIndexOutOfBoundsException
* if <code>index</code> is out of bounds
*/
public static Object get(Object array, int index) {
if(array instanceof Object[])
return ((Object[])array)[index];
if(array instanceof boolean[])
return ((boolean[])array)[index] ? Boolean.TRUE : Boolean.FALSE;
if(array instanceof byte[])
return new Byte(((byte[])array)[index]);
if(array instanceof char[])
return new Character(((char[])array)[index]);
if(array instanceof short[])
return new Short(((short[])array)[index]);
if(array instanceof int[])
return new Integer(((int[])array)[index]);
if(array instanceof long[])
return new Long(((long[])array)[index]);
if(array instanceof float[])
return new Float(((float[])array)[index]);
if(array instanceof double[])
return new Double(((double[])array)[index]);
if(array == null)
throw new NullPointerException();
throw new IllegalArgumentException();
}

/**
*Assigns the specified int value to each element of the specified any dimensional array
Expand Down
26 changes: 1 addition & 25 deletions src/main/java/org/numenta/nupic/util/NamedTuple.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@

package org.numenta.nupic.util;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

import org.numenta.nupic.model.Persistable;
import static org.numenta.nupic.util.ArrayUtils.interleave;

/**
* Immutable tuple which adds associative lookup functionality.
Expand Down Expand Up @@ -408,28 +408,4 @@ public boolean equals(Object obj) {
}
}

/**
* Returns an array containing the successive elements of each
* argument array as in [ first[0], second[0], first[1], second[1], ... ].
*
* Arrays may be of zero length, and may be of different sizes, but may not be null.
*
* @param first the first array
* @param second the second array
* @return
*/
static <F, S> Object[] interleave(F first, S second) {
int flen, slen;
Object[] retVal = new Object[(flen = Array.getLength(first)) + (slen = Array.getLength(second))];
for(int i = 0, j = 0, k = 0;i < flen || j < slen;) {
if(i < flen) {
retVal[k++] = Array.get(first, i++);
}
if(j < slen) {
retVal[k++] = Array.get(second, j++);
}
}

return retVal;
}
}
20 changes: 10 additions & 10 deletions src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,18 @@ private void back(int val, int... coordinates) {
*
* @param coordinates the coordinates which specify the returned array
* @return the array specified
* @throws IllegalArgumentException if the specified coordinates address
* an actual value instead of the array holding it.
* @throws IllegalArgumentException if the specified coordinates address
* an actual value instead of the array holding it.
*/
@Override
public Object getSlice(int... coordinates) {
Object slice = ArrayUtils.getValue(this.backingArray, coordinates);
//Ensure return value is of type Array
if(!slice.getClass().isArray()) {
sliceError(coordinates);
try {
return ArrayUtils.getSlice(this.backingArray, coordinates);
} catch(ClassCastException e) {
throw new IllegalArgumentException(
"This method only returns the array holding the specified maximum index: " +
Arrays.toString(dimensions));
}

return slice;
}

/**
Expand Down Expand Up @@ -177,7 +177,7 @@ public AbstractSparseBinaryMatrix set(int[] indexes, int[] values) {
*/
public void clearStatistics(int row) {
this.setTrueCount(row, 0);
int[] slice = (int[])Array.get(backingArray, row);
int[] slice = (int[])ArrayUtils.getSlice(backingArray, row);
Arrays.fill(slice, 0);
}

Expand All @@ -194,7 +194,7 @@ public Integer get(int index) {
return Array.getInt(this.backingArray, index);
}

else return (Integer) ArrayUtils.getValue(this.backingArray, coordinates);
else return (Integer) ArrayUtils.getIntValue(this.backingArray, coordinates);
}

@Override
Expand Down
Loading