diff --git a/tests/test_preprocessing/test_pandas_feature_selector.py b/tests/test_preprocessing/test_pandas_feature_selector.py index d6195e0..3504ae2 100644 --- a/tests/test_preprocessing/test_pandas_feature_selector.py +++ b/tests/test_preprocessing/test_pandas_feature_selector.py @@ -5,13 +5,13 @@ import timeserio.ini as ini from timeserio.data.mock import mock_fit_data from timeserio.preprocessing import ( - PandasColumnSelector, PandasValueSelector, - PandasIndexValueSelector, PandasSequenceSplitter + PandasColumnSelector, PandasValueSelector, PandasIndexValueSelector, + PandasSequenceSplitter ) - datetime_column = ini.Columns.datetime usage_column = ini.Columns.target +id_column = ini.Columns.id @pytest.fixture @@ -66,6 +66,12 @@ def test_value_selector(df, columns, shape1): assert subarray.shape == expected_shape +@pytest.mark.parametrize("dtype", ["uint8", "int8"]) +def test_value_selector_dtype(df, dtype): + subarray = PandasValueSelector(columns="id", dtype=dtype).transform(df) + assert subarray.dtype == dtype + + @pytest.mark.parametrize( 'levels, shape1', [ (None, 0), @@ -83,6 +89,13 @@ def test_index_value_selector(indexed_df, levels, shape1): assert subarray.shape == expected_shape +@pytest.mark.parametrize("dtype", ["uint8", "int8"]) +def test_index_value_selector_dtype(indexed_df, dtype): + subarray = PandasIndexValueSelector(levels="id", + dtype=dtype).transform(indexed_df) + assert subarray.dtype == dtype + + @pytest.mark.parametrize( 'transformer, required_columns', [ (PandasColumnSelector('col1'), {'col1'}), diff --git a/timeserio/preprocessing/pandas.py b/timeserio/preprocessing/pandas.py index e9fd786..f7eab49 100644 --- a/timeserio/preprocessing/pandas.py +++ b/timeserio/preprocessing/pandas.py @@ -82,10 +82,14 @@ def _get_column_as_tensor(s: pd.Series): class PandasValueSelector(BaseEstimator, TransformerMixin): - """Select scalar - or vector-valued feature cols, and return np.array.""" + """Select scalar - or vector-valued feature cols, and return np.array. - def __init__(self, columns=None): + Optionally, cast the resulting arry to dtype. + """ + + def __init__(self, columns=None, dtype=None): self.columns = columns + self.dtype = dtype def fit(self, df, y=None, **fit_params): return self @@ -98,6 +102,8 @@ def transform(self, df): else: # support a mix of compatible tensors and regular columns blocks = [_get_column_as_tensor(df[col]) for col in columns] subarray = np.hstack(blocks) + if self.dtype: + subarray = subarray.astype(self.dtype) return subarray @property @@ -112,10 +118,14 @@ def transformed_columns(self, input_columns): class PandasIndexValueSelector(BaseEstimator, TransformerMixin): - """Select index levels as feature cols, and return np.array.""" + """Select index levels as feature cols, and return np.array. + + Optionally, cast the resulting arry to dtype. + """ - def __init__(self, levels=None): + def __init__(self, levels=None, dtype=None): self.levels = levels + self.dtype = dtype def fit(self, df, y=None, **fit_params): return self @@ -133,6 +143,8 @@ def transform(self, df): for level in levels ] subarray = np.hstack(blocks) if blocks else np.empty((len(df), 0)) + if self.dtype: + subarray = subarray.astype(self.dtype) return subarray