diff --git a/stingray/utils.py b/stingray/utils.py index f9ba9a221..793520944 100644 --- a/stingray/utils.py +++ b/stingray/utils.py @@ -1,5 +1,6 @@ import numbers import os +import re import copy import random import string @@ -140,6 +141,149 @@ def mad(data, c=0.6745, axis=None): ] +@njit +def any_complex_in_array(array): + """Check if any element of an array is complex. + + Examples + -------- + >>> any_complex_in_array(np.array([1, 2, 3])) + False + >>> any_complex_in_array(np.array([1, 2 + 1.j, 3])) + True + """ + for a in array: + if np.iscomplex(a): + return True + return False + + +def make_nd_into_arrays(array: np.ndarray, label: str) -> dict: + """If an array is n-dimensional, make it into many 1-dimensional arrays. + + Parameters + ---------- + array : `np.ndarray` + Input data + label : `str` + Label for the array + + Returns + ------- + data : `dict` + Dictionary of arrays. Defauls to ``{label: array}`` if ``array`` is 1-dimensional, + otherwise, e.g.: ``{label_dim1_2_3: array[1, 2, 3], ... }`` + + Examples + -------- + >>> a1, a2, a3 = np.arange(3), np.arange(3, 6), np.arange(6, 9) + >>> A = np.array([a1, a2, a3]).T + >>> data = make_nd_into_arrays(A, "test") + >>> np.array_equal(data["test_dim0"], a1) + True + >>> np.array_equal(data["test_dim1"], a2) + True + >>> np.array_equal(data["test_dim2"], a3) + True + >>> A3 = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> data = make_nd_into_arrays(A3, "test") + >>> np.array_equal(data["test_dim0_0"], [1, 5]) + True + """ + data = {} + array = np.asarray(array) + shape = np.shape(array) + ndim = len(shape) + if ndim <= 1: + data[label] = array + else: + for i in range(shape[1]): + new_label = f"_dim{i}" if "_dim" not in label else f"_{i}" + dumdata = make_nd_into_arrays(array[:, i], label=label + new_label) + data.update(dumdata) + return data + + +def get_dimensions_from_list_of_column_labels(labels: list, label: str) -> list: + """Get the dimensions of a multi-dimensional array from a list of column labels. + + Examples + -------- + >>> labels = ['test_dim0_0', 'test_dim0_1', 'test_dim0_2', + ... 'test_dim1_0', 'test_dim1_1', 'test_dim1_2', 'test', 'bu'] + >>> keys, dimensions = get_dimensions_from_list_of_column_labels(labels, "test") + >>> for key0, key1 in zip(labels[:6], keys): assert key0 == key1 + >>> np.array_equal(dimensions, [2, 3]) + True + """ + all_keys = [] + count_dimensions = None + for key in labels: + if label not in key: + continue + match = re.search(label + r"_dim([0-9]+(_[0-9]+)*)", key) + if match is None: + continue + all_keys.append(key) + new_count_dimensions = [int(val) for val in match.groups()[0].split("_")] + if count_dimensions is None: + count_dimensions = np.array(new_count_dimensions) + else: + count_dimensions = np.max([count_dimensions, new_count_dimensions], axis=0) + + return sorted(all_keys), count_dimensions + 1 + + +def make_1d_arrays_into_nd(data: dict, label: str) -> np.ndarray: + """Literally the opposite of make_nd_into_arrays. + + Parameters + ---------- + data : dict + Input data + label : `str` + Label for the array + + Returns + ------- + array : `np.array` + N-dimensional array that was stored in the data. + + Examples + -------- + >>> a1, a2, a3 = np.arange(3), np.arange(3, 6), np.arange(6, 9) + >>> A = np.array([a1, a2, a3]).T + >>> data = make_nd_into_arrays(A, "test") + >>> A_ret = make_1d_arrays_into_nd(data, "test") + >>> np.array_equal(A, A_ret) + True + >>> A = np.array([[[1, 2, 12], [3, 4, 34]], + ... [[5, 6, 56], [7, 8, 78]], + ... [[9, 10, 910], [11, 12, 1112]], + ... [[13, 14, 1314], [15, 16, 1516]]]) + >>> data = make_nd_into_arrays(A, "test") + >>> A_ret = make_1d_arrays_into_nd(data, "test") + >>> np.array_equal(A, A_ret) + True + >>> data = make_nd_into_arrays(a1, "test") + >>> A_ret = make_1d_arrays_into_nd(data, "test") + >>> np.array_equal(a1, A_ret) + True + """ + + if label in list(data.keys()): + return data[label] + + # Get the dimensionality of the data + dim = 0 + all_keys = [] + + all_keys, dimensions = get_dimensions_from_list_of_column_labels(list(data.keys()), label) + arrays = np.array([np.array(data[key]) for key in all_keys]) + + return arrays.T.reshape([len(arrays[0])] + list(dimensions)) + + @njit def _check_isallfinite_numba(array): """Check if all elements of an array are finite.