From 330b67a3b91cc2de915b42136b1626fc8e2e904f Mon Sep 17 00:00:00 2001 From: Son Pham-Ba Date: Thu, 4 Jul 2024 16:54:33 +0200 Subject: [PATCH] feat: wip add tslong to tswide conversion --- tstore/backend.py | 18 ++++++++++ tstore/tests/test_tslong.py | 28 ++++++++++++++++ tstore/tslong/__init__.py | 2 +- tstore/tslong/dask.py | 20 ++++++++++- tstore/tslong/pandas.py | 17 +++++++++- tstore/tslong/polars.py | 6 +--- tstore/tslong/pyarrow.py | 6 +--- tstore/tslong/tslong.py | 23 ++++--------- tstore/tswide/__init__.py | 2 +- tstore/tswide/tswide.py | 66 +++++++++++++++++++++++++++++++++---- 10 files changed, 150 insertions(+), 38 deletions(-) diff --git a/tstore/backend.py b/tstore/backend.py index 84ae998..d7af8b6 100644 --- a/tstore/backend.py +++ b/tstore/backend.py @@ -314,3 +314,21 @@ def _change_series_backend_from_pyarrow( return series raise ValueError(f"Unsupported backend: {new_backend}") + + +def cast_column_to_large_string(df: DataFrame, col: str) -> DataFrame: + """Cast a column to a large string type.""" + if isinstance(df, (DaskDataFrame, PandasDataFrame)): + df[col] = df[col].astype("large_string[pyarrow]") + + elif isinstance(df, PolarsDataFrame): + df = df.cast({col: pl.String}) + + elif isinstance(df, PyArrowDataFrame): + schema = df.schema + field_index = schema.get_field_index(col) + schema = schema.remove(field_index) + schema = schema.insert(field_index, pa.field(col, pa.large_string())) + df = df.cast(target_schema=schema) + + return df diff --git a/tstore/tests/test_tslong.py b/tstore/tests/test_tslong.py index d876680..d63aad7 100644 --- a/tstore/tests/test_tslong.py +++ b/tstore/tests/test_tslong.py @@ -20,6 +20,10 @@ from tstore.tslong.pandas import TSLongPandas from tstore.tslong.polars import TSLongPolars from tstore.tslong.pyarrow import TSLongPyArrow +from tstore.tswide.dask import TSWideDask +from tstore.tswide.pandas import TSWidePandas +from tstore.tswide.polars import TSWidePolars +from tstore.tswide.pyarrow import TSWidePyArrow # Imported fixtures from conftest.py: # - dask_long_dataframe @@ -46,6 +50,13 @@ "pyarrow": TSDFPyArrow, } +tswide_classes = { + "dask": TSWideDask, + "pandas": TSWidePandas, + "polars": TSWidePolars, + "pyarrow": TSWidePyArrow, +} + # Functions #################################################################### @@ -231,3 +242,20 @@ def test_to_tsdf( np.testing.assert_array_equal(tsdf["tstore_id"], ["1", "2", "3", "4"]) np.testing.assert_array_equal(tsdf["static_var1"], ["A", "B", "C", "D"]) np.testing.assert_array_equal(tsdf["static_var2"], [1.0, 2.0, 3.0, 4.0]) + + +@pytest.mark.parametrize("backend", ["dask", "pandas", "polars", "pyarrow"]) +def test_to_tswide( + backend: str, + request, +) -> None: + """Test the to_tsdf function.""" + tslong_fixture_name = f"{backend}_tslong" + tslong = request.getfixturevalue(tslong_fixture_name) + tswide = tslong.to_tswide() + + assert isinstance(tswide, tswide_classes[backend]) + assert tswide._tstore_id_var == "tstore_id" + assert tswide._tstore_time_var == "time" + assert tswide._tstore_ts_vars == {"ts_var1": ["var1", "var2"], "ts_var2": ["var3", "var4"]} + assert tswide._tstore_static_vars == ["static_var1", "static_var2"] diff --git a/tstore/tslong/__init__.py b/tstore/tslong/__init__.py index e1c08dd..a7dffc6 100644 --- a/tstore/tslong/__init__.py +++ b/tstore/tslong/__init__.py @@ -11,7 +11,7 @@ from tstore.tslong.tslong import TSLong -def open_tslong(base_dir: Union[str, Path], *args, backend: Backend = "pandas", **kwargs): +def open_tslong(base_dir: Union[str, Path], *args, backend: Backend = "dask", **kwargs): """Read a TStore file structure as a TSLong object.""" ts_long_classes = { "dask": TSLongDask, diff --git a/tstore/tslong/dask.py b/tstore/tslong/dask.py index 30bde33..74c8b07 100644 --- a/tstore/tslong/dask.py +++ b/tstore/tslong/dask.py @@ -122,4 +122,22 @@ def _get_static_values(self) -> dict[str, list]: def to_tswide(self) -> "TSWideDask": """Convert the wrapper into a TSWideDask object.""" - raise NotImplementedError + from tstore.tswide.dask import TSWideDask + + df = self._obj + df = df.reset_index() + df[self._tstore_id_var] = df[self._tstore_id_var].astype("category").compute() + df = df.pivot_table( + index=self._tstore_time_var, + columns=self._tstore_id_var, + values=df.columns.difference([self._tstore_id_var]), + aggfunc="first", + ) + + return TSWideDask( + df, + id_var=self._tstore_id_var, + time_var=self._tstore_time_var, + ts_vars=self._tstore_ts_vars, + static_vars=self._tstore_static_vars, + ) diff --git a/tstore/tslong/pandas.py b/tstore/tslong/pandas.py index 0da8f84..7b257cc 100644 --- a/tstore/tslong/pandas.py +++ b/tstore/tslong/pandas.py @@ -162,4 +162,19 @@ def from_tstore( def to_tswide(self) -> "TSWidePandas": """Convert the wrapper into a TSWide object.""" - raise NotImplementedError + from tstore.tswide.pandas import TSWidePandas + + df = self._obj + df = df.pivot_table( + index=self._tstore_time_var, + columns=self._tstore_id_var, + aggfunc="first", + ) + + return TSWidePandas( + df, + id_var=self._tstore_id_var, + time_var=self._tstore_time_var, + ts_vars=self._tstore_ts_vars, + static_vars=self._tstore_static_vars, + ) diff --git a/tstore/tslong/polars.py b/tstore/tslong/polars.py index c8b44b9..871f674 100644 --- a/tstore/tslong/polars.py +++ b/tstore/tslong/polars.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: # To avoid circular imports - from tstore.tswide.polars import TSWidePolars + pass class TSLongPolars(TSLong): @@ -155,7 +155,3 @@ def from_tstore( # Conversion to polars return tslong_pyarrow.change_backend(new_backend="polars") - - def to_tswide(self) -> "TSWidePolars": - """Convert the wrapper into a TSWide object.""" - raise NotImplementedError diff --git a/tstore/tslong/pyarrow.py b/tstore/tslong/pyarrow.py index cc220a2..4c00b0e 100644 --- a/tstore/tslong/pyarrow.py +++ b/tstore/tslong/pyarrow.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: # To avoid circular imports - from tstore.tswide.pyarrow import TSWidePyArrow + pass class TSLongPyArrow(TSLong): @@ -95,10 +95,6 @@ def from_tstore( static_vars=static_vars, ) - def to_tswide(self) -> "TSWidePyArrow": - """Convert the wrapper into a TSWide object.""" - raise NotImplementedError - def _read_ts( fpath, diff --git a/tstore/tslong/tslong.py b/tstore/tslong/tslong.py index 9d46971..a2329fe 100644 --- a/tstore/tslong/tslong.py +++ b/tstore/tslong/tslong.py @@ -1,11 +1,7 @@ """Module defining the TSLong abstract wrapper.""" -from abc import abstractmethod from typing import TYPE_CHECKING, Optional -import polars as pl -import pyarrow as pa - from tstore.backend import ( Backend, DaskDataFrame, @@ -13,6 +9,7 @@ PandasDataFrame, PolarsDataFrame, PyArrowDataFrame, + cast_column_to_large_string, change_backend, re_set_dataframe_index, ) @@ -45,18 +42,7 @@ def __init__( Defaults to None, which will group all columns not in `static_vars` together. static_vars (list[str]): List of column names that are static across time. Defaults to None. """ - if isinstance(df, (DaskDataFrame, PandasDataFrame)): - df[id_var] = df[id_var].astype("large_string[pyarrow]") - - elif isinstance(df, PolarsDataFrame): - df = df.cast({id_var: pl.String}) - - elif isinstance(df, PyArrowDataFrame): - schema = df.schema - field_index = schema.get_field_index(id_var) - schema = schema.remove(field_index) - schema = schema.insert(field_index, pa.field(id_var, pa.large_string())) - df = df.cast(target_schema=schema) + df = cast_column_to_large_string(df, id_var) # Ensure correct index column df = re_set_dataframe_index(df, index_var=time_var) @@ -130,6 +116,9 @@ def to_tsdf(self) -> "TSDF": tsdf = dask_tsdf.change_backend(new_backend=self.current_backend) return tsdf - @abstractmethod def to_tswide(self) -> "TSWide": """Convert the wrapper into a TSWide object.""" + dask_tslong = self.change_backend(new_backend="dask") + dask_tswide = dask_tslong.to_tswide() + tswide = dask_tswide.change_backend(new_backend=self.current_backend) + return tswide diff --git a/tstore/tswide/__init__.py b/tstore/tswide/__init__.py index 790a1b7..2ce32bb 100644 --- a/tstore/tswide/__init__.py +++ b/tstore/tswide/__init__.py @@ -11,7 +11,7 @@ from tstore.tswide.tswide import TSWide -def open_tswide(base_dir: Union[str, Path], *args, backend: Backend = "pandas", **kwargs): +def open_tswide(base_dir: Union[str, Path], *args, backend: Backend = "dask", **kwargs): """Read a TStore file structure as a TSWide object.""" ts_wide_classes = { "dask": TSWideDask, diff --git a/tstore/tswide/tswide.py b/tstore/tswide/tswide.py index 86bbe1a..c07ebca 100644 --- a/tstore/tswide/tswide.py +++ b/tstore/tswide/tswide.py @@ -1,9 +1,15 @@ """Module defining the TSWide abstract wrapper.""" from abc import abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional -from tstore.backend import DaskDataFrame, DataFrame, PandasDataFrame, PolarsDataFrame, PyArrowDataFrame +from tstore.backend import ( + DaskDataFrame, + DataFrame, + PandasDataFrame, + PolarsDataFrame, + PyArrowDataFrame, +) from tstore.tswrapper.tswrapper import TSWrapper if TYPE_CHECKING: @@ -15,6 +21,52 @@ class TSWide(TSWrapper): """Abstract wrapper for a wide-form timeseries DataFrame.""" + def __init__( + self, + df: DataFrame, + id_var: str, + time_var: str = "time", + ts_vars: Optional[dict[str, list[str]]] = None, + static_vars: Optional[list[str]] = None, + ) -> None: + """Wrap a wide-form timeseries DataFrame as a TSWide object. + + Args: + df (DataFrame): DataFrame to wrap. + id_var (str): Name of the column containing the identifier variable. + time_var (str): Name of the column containing the time variable. Defaults to "time". + ts_vars (dict[str, list[str]]): Dictionary of named groups of column names. + Defaults to None, which will group all columns not in `static_vars` together. + static_vars (list[str]): List of column names that are static across time. Defaults to None. + """ + # TODO: Cast id_var to large string + # df = cast_column_to_large_string(df, id_var) + + # TODO: Ensure correct index column + # df = re_set_dataframe_index(df, index_var=time_var) + + super().__init__(df) + + if static_vars is None: + static_vars = [] + + if ts_vars is None: + ts_vars = { + "ts_variable": [ + col for col in df.columns if col != id_var and col != time_var and col not in static_vars + ], + } + + # Set attributes using __dict__ to not trigger __setattr__ + self.__dict__.update( + { + "_tstore_id_var": id_var, + "_tstore_time_var": time_var, + "_tstore_ts_vars": ts_vars, + "_tstore_static_vars": static_vars, + }, + ) + def __new__(cls, *args, **kwargs) -> "TSWide": """When calling TSWide() directly, return the appropriate subclass.""" if cls is TSWide: @@ -24,7 +76,7 @@ def __new__(cls, *args, **kwargs) -> "TSWide": return super().__new__(cls) @staticmethod - def wrap(df: DataFrame) -> "TSWide": + def wrap(df: DataFrame, *args, **kwargs) -> "TSWide": """Wrap a DataFrame in the appropriate TSWide subclass.""" # Lazy import to avoid circular imports from tstore.tswide.dask import TSWideDask @@ -33,16 +85,16 @@ def wrap(df: DataFrame) -> "TSWide": from tstore.tswide.pyarrow import TSWidePyArrow if isinstance(df, DaskDataFrame): - return TSWideDask(df) + return TSWideDask(df, *args, **kwargs) if isinstance(df, PandasDataFrame): - return TSWidePandas(df) + return TSWidePandas(df, *args, **kwargs) if isinstance(df, PolarsDataFrame): - return TSWidePolars(df) + return TSWidePolars(df, *args, **kwargs) if isinstance(df, PyArrowDataFrame): - return TSWidePyArrow(df) + return TSWidePyArrow(df, *args, **kwargs) type_path = f"{type(df).__module__}.{type(df).__qualname__}" raise TypeError(f"Cannot wrap type {type_path} as a TSWide object.")