From a09aa6d1a29dfa38b171b56c055740d8f2d29bf8 Mon Sep 17 00:00:00 2001 From: "Toni M. Brotons" <10654467+toni-neurosc@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:26:01 +0200 Subject: [PATCH] Fix handling of dataframe input, issue #300 (#301) --- py_neuromodulation/nm_stream_offline.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/py_neuromodulation/nm_stream_offline.py b/py_neuromodulation/nm_stream_offline.py index 7df72243..708dcbb8 100644 --- a/py_neuromodulation/nm_stream_offline.py +++ b/py_neuromodulation/nm_stream_offline.py @@ -88,9 +88,9 @@ def _handle_data(self, data: np.ndarray | pd.DataFrame) -> np.ndarray: if not len(names_expected) == data.shape[0]: raise ValueError( "If data is passed as an array, the first dimension must" - " match the number of channel names in `nm_channels`. Got:" - f" Data columns: {data.shape[0]}, nm_channels.name:" - f" {len(names_expected)}." + " match the number of channel names in `nm_channels`.\n" + f" Number of data channels (data.shape[0]): {data.shape[0]}\n" + f" Length of nm_channels[\"name\"]: {len(names_expected)}." ) return data names_data = data.columns.to_list() @@ -100,10 +100,11 @@ def _handle_data(self, data: np.ndarray | pd.DataFrame) -> np.ndarray: ): raise ValueError( "If data is passed as a DataFrame, the" - "columns must match the channel names in `nm_channels`. Got:" - f"Data columns: {names_data}, nm_channels.name: {names_data}." + "column names must match the channel names in `nm_channels`.\n" + f"Input dataframe column names: {names_data}\n" + f"Expected (from nm_channels[\"name\"]): : {names_expected}." ) - return data.to_numpy() + return data.to_numpy().transpose() def _check_settings_for_parallel(self): """Check specified settings and raise error if parallel processing is not possible.