From 8b9e859a68c2567a2ce7e5e30353f2f66451f718 Mon Sep 17 00:00:00 2001
From: Douglas Raillard
Date: Tue, 6 Jul 2021 11:41:19 +0100
Subject: [PATCH] lisa.trace: Save dataframe custom attrs to parquet

Use a workaround until this ENH is implemented:
https://github.com/pandas-dev/pandas/issues/20521
---
 lisa/trace.py | 45 +++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 43 insertions(+), 2 deletions(-)

diff --git a/lisa/trace.py b/lisa/trace.py
index 2aa71d7215..76ac6299f2 100644
--- a/lisa/trace.py
+++ b/lisa/trace.py
@@ -47,6 +47,7 @@
 import pandas as pd
 from pandas.api.types import is_numeric_dtype
 import pyarrow.lib
+import pyarrow.parquet
 
 import devlib
 
@@ -2826,11 +2827,51 @@ def _update_swap_cost(self, data, swap_cost, mem_usage, swap_size):
 
     def _is_written_to_swap(self, pd_desc):
         return pd_desc.normal_form in self._swap_content
 
+    @staticmethod
+    def _data_to_parquet(data, path, **kwargs):
+        """
+        Equivalent to `df.to_parquet(...)` but workaround until pandas can save
+        attrs to parquet on its own: ENH request on pandas:
+        https://github.com/pandas-dev/pandas/issues/20521
+
+        Workaround:
+        https://github.com/pandas-dev/pandas/pull/20534#issuecomment-453236538
+        """
+        if isinstance(data, pd.DataFrame):
+            # Data must be convertible to bytes so we dump them as JSON
+            attrs = json.dumps(data.attrs)
+            table = pyarrow.Table.from_pandas(data)
+            updated_metadata = dict(
+                table.schema.metadata or {},
+                lisa=attrs,
+            )
+            table = table.replace_schema_metadata(updated_metadata)
+            pyarrow.parquet.write_table(table, path, **kwargs)
+        else:
+            data.to_parquet(path, **kwargs)
+
+    @staticmethod
+    def _data_from_parquet(path):
+        """
+        Equivalent to `pd.read_parquet(...)` but also load the metadata back
+        into dataframes's attrs
+        """
+        data = pd.read_parquet(path)
+
+        # Load back LISA metadata into "df.attrs", as they were written in
+        # _data_to_parquet()
+        if isinstance(data, pd.DataFrame):
+            schema = pyarrow.parquet.read_schema(path)
+            attrs = schema.metadata.get(b'lisa', '{}')
+            data.attrs = json.loads(attrs)
+
+        return data
+
     @classmethod
     def _write_data(cls, data, path):
         if cls.DATAFRAME_SWAP_FORMAT == 'parquet':
             # Snappy compression seems very fast
-            data.to_parquet(path, compression='snappy', index=True)
+            cls._data_to_parquet(data, path, compression='snappy')
         else:
             raise ValueError(f'Dataframe swap format "{cls.DATAFRAME_SWAP_FORMAT}" not handled')
@@ -2974,7 +3015,7 @@ def fetch(self, pd_desc, insert=True):
         # Try to load the dataframe from that path
         try:
             if self.DATAFRAME_SWAP_FORMAT == 'parquet':
-                data = pd.read_parquet(path)
+                data = self._data_from_parquet(path)
             else:
                 raise ValueError(f'Dataframe swap format "{self.DATAFRAME_SWAP_FORMAT}" not handled')
         except (OSError, pyarrow.lib.ArrowIOError):