diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py index fd4c6e42d4ae8..57fd7c61d6bd7 100644 --- a/pandas/io/parquet.py +++ b/pandas/io/parquet.py @@ -143,6 +143,44 @@ def read(self, path, columns=None, **kwargs): raise AbstractMethodError(self) +def _pyarrow_write_attrs(table: Any, df: DataFrame) -> Any: + """ + .. versionadded:: 1.3 + + Copy attts from pandas.DataFrame and pandas.Series to + schema metadata in pyarrow.Table. + """ + schema_metadata = table.schema.metadata or {} + pandas_metadata = json.loads(schema_metadata.get(b"pandas", "{}")) + column_attrs = {} + for col in df.columns: + attrs = df[col].attrs + if not attrs or not isinstance(col, str): + continue + column_attrs[col] = attrs + pandas_metadata.update( + attrs=df.attrs, + column_attrs=column_attrs, + ) + schema_metadata[b"pandas"] = json.dumps(pandas_metadata) + return table.replace_schema_metadata(schema_metadata) + + +def _pyarrow_read_attrs(table: Any, df: DataFrame) -> None: + """ + .. versionadded:: 1.3 + + Copy schema metadata from pyarrow.Table + to attrs in pandas.DataFrame and pandas.Series. + """ + schema_metadata = table.schema.metadata or {} + pandas_metadata = json.loads(schema_metadata.get(b"pandas", "{}")) + df.attrs = pandas_metadata.get("attrs", {}) + col_attrs = pandas_metadata.get("column_attrs", {}) + for col in df.columns: + df[col].attrs = col_attrs.get(col, {}) + + class PyArrowImpl(BaseImpl): def __init__(self): import_optional_dependency( @@ -155,32 +193,6 @@ def __init__(self): self.api = pyarrow - @staticmethod - def _write_attrs(table, df: DataFrame): - schema_metadata = table.schema.metadata or {} - pandas_metadata = json.loads(schema_metadata.get(b"pandas", "{}")) - column_attrs = {} - for col in df.columns: - attrs = df[col].attrs - if not attrs or not isinstance(col, str): - continue - column_attrs[col] = attrs - pandas_metadata.update( - attrs=df.attrs, - column_attrs=column_attrs, - ) - schema_metadata[b"pandas"] = json.dumps(pandas_metadata) - return table.replace_schema_metadata(schema_metadata) - - @staticmethod - def _read_attrs(table, df: DataFrame): - schema_metadata = table.schema.metadata or {} - pandas_metadata = json.loads(schema_metadata.get(b"pandas", "{}")) - df.attrs = pandas_metadata.get("attrs", {}) - col_attrs = pandas_metadata.get("column_attrs", {}) - for col in df.columns: - df[col].attrs = col_attrs.get(col, {}) - def write( self, df: DataFrame, @@ -198,7 +210,7 @@ def write( from_pandas_kwargs["preserve_index"] = index table = self.api.Table.from_pandas(df, **from_pandas_kwargs) - table = self._write_attrs(table, df) + table = _pyarrow_write_attrs(table, df) path_or_handle, handles, kwargs["filesystem"] = _get_path_or_handle( path, @@ -268,7 +280,7 @@ def read( path_or_handle, columns=columns, **kwargs ) result = table.to_pandas(**to_pandas_kwargs) - self._read_attrs(table, result) + _pyarrow_read_attrs(table, result) if manager == "array": result = result._as_manager("array", copy=False) return result