Skip to content

Commit

Permalink
Add functions to make nd arrays into lists of single arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Oct 11, 2023
1 parent 60b2fa9 commit d707496
Showing 1 changed file with 144 additions and 0 deletions.
144 changes: 144 additions & 0 deletions stingray/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numbers
import os
import re
import copy
import random
import string
Expand Down Expand Up @@ -140,6 +141,149 @@ def mad(data, c=0.6745, axis=None):
]


@njit
def any_complex_in_array(array):
"""Check if any element of an array is complex.
Examples
--------
>>> any_complex_in_array(np.array([1, 2, 3]))
False
>>> any_complex_in_array(np.array([1, 2 + 1.j, 3]))
True
"""
for a in array:
if np.iscomplex(a):
return True
return False


def make_nd_into_arrays(array: np.ndarray, label: str) -> dict:
"""If an array is n-dimensional, make it into many 1-dimensional arrays.
Parameters
----------
array : `np.ndarray`
Input data
label : `str`
Label for the array
Returns
-------
data : `dict`
Dictionary of arrays. Defauls to ``{label: array}`` if ``array`` is 1-dimensional,
otherwise, e.g.: ``{label_dim1_2_3: array[1, 2, 3], ... }``
Examples
--------
>>> a1, a2, a3 = np.arange(3), np.arange(3, 6), np.arange(6, 9)
>>> A = np.array([a1, a2, a3]).T
>>> data = make_nd_into_arrays(A, "test")
>>> np.array_equal(data["test_dim0"], a1)
True
>>> np.array_equal(data["test_dim1"], a2)
True
>>> np.array_equal(data["test_dim2"], a3)
True
>>> A3 = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
>>> data = make_nd_into_arrays(A3, "test")
>>> np.array_equal(data["test_dim0_0"], [1, 5])
True
"""
data = {}
array = np.asarray(array)
shape = np.shape(array)
ndim = len(shape)
if ndim <= 1:
data[label] = array
else:
for i in range(shape[1]):
new_label = f"_dim{i}" if "_dim" not in label else f"_{i}"
dumdata = make_nd_into_arrays(array[:, i], label=label + new_label)
data.update(dumdata)
return data


def get_dimensions_from_list_of_column_labels(labels: list, label: str) -> list:
"""Get the dimensions of a multi-dimensional array from a list of column labels.
Examples
--------
>>> labels = ['test_dim0_0', 'test_dim0_1', 'test_dim0_2',
... 'test_dim1_0', 'test_dim1_1', 'test_dim1_2', 'test', 'bu']
>>> keys, dimensions = get_dimensions_from_list_of_column_labels(labels, "test")
>>> for key0, key1 in zip(labels[:6], keys): assert key0 == key1
>>> np.array_equal(dimensions, [2, 3])
True
"""
all_keys = []
count_dimensions = None
for key in labels:
if label not in key:
continue
match = re.search(label + r"_dim([0-9]+(_[0-9]+)*)", key)
if match is None:
continue
all_keys.append(key)
new_count_dimensions = [int(val) for val in match.groups()[0].split("_")]
if count_dimensions is None:
count_dimensions = np.array(new_count_dimensions)
else:
count_dimensions = np.max([count_dimensions, new_count_dimensions], axis=0)

return sorted(all_keys), count_dimensions + 1


def make_1d_arrays_into_nd(data: dict, label: str) -> np.ndarray:
"""Literally the opposite of make_nd_into_arrays.
Parameters
----------
data : dict
Input data
label : `str`
Label for the array
Returns
-------
array : `np.array`
N-dimensional array that was stored in the data.
Examples
--------
>>> a1, a2, a3 = np.arange(3), np.arange(3, 6), np.arange(6, 9)
>>> A = np.array([a1, a2, a3]).T
>>> data = make_nd_into_arrays(A, "test")
>>> A_ret = make_1d_arrays_into_nd(data, "test")
>>> np.array_equal(A, A_ret)
True
>>> A = np.array([[[1, 2, 12], [3, 4, 34]],
... [[5, 6, 56], [7, 8, 78]],
... [[9, 10, 910], [11, 12, 1112]],
... [[13, 14, 1314], [15, 16, 1516]]])
>>> data = make_nd_into_arrays(A, "test")
>>> A_ret = make_1d_arrays_into_nd(data, "test")
>>> np.array_equal(A, A_ret)
True
>>> data = make_nd_into_arrays(a1, "test")
>>> A_ret = make_1d_arrays_into_nd(data, "test")
>>> np.array_equal(a1, A_ret)
True
"""

if label in list(data.keys()):
return data[label]

# Get the dimensionality of the data
dim = 0
all_keys = []

all_keys, dimensions = get_dimensions_from_list_of_column_labels(list(data.keys()), label)
arrays = np.array([np.array(data[key]) for key in all_keys])

return arrays.T.reshape([len(arrays[0])] + list(dimensions))


@njit
def _check_isallfinite_numba(array):
"""Check if all elements of an array are finite.
Expand Down

0 comments on commit d707496

Please sign in to comment.