diff --git a/xarray_sql/core.py b/xarray_sql/core.py index 82f4a3d..288eae0 100644 --- a/xarray_sql/core.py +++ b/xarray_sql/core.py @@ -7,6 +7,7 @@ Row = t.List[t.Any] +# deprecated def get_columns(ds: xr.Dataset) -> t.List[str]: return list(ds.dims.keys()) + list(ds.data_vars.keys()) diff --git a/xarray_sql/df.py b/xarray_sql/df.py index addb93e..1844328 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -101,7 +101,7 @@ def pivot(b: Block) -> pd.DataFrame: f'{"_".join(list(ds.data_vars.keys()))}' ) - columns = core.get_columns(ds) + columns = pivot(blocks[0]).columns # TODO(#18): Is it possible to pass the length (known now) here? meta = {c: ds[c].dtype for c in columns} diff --git a/xarray_sql/df_test.py b/xarray_sql/df_test.py index 8688ad0..9007145 100644 --- a/xarray_sql/df_test.py +++ b/xarray_sql/df_test.py @@ -3,11 +3,37 @@ import dask.dataframe as dd import numpy as np +import pandas as pd import xarray as xr from .df import explode, read_xarray, block_slices +def rand_wx(start: str, end: str) -> xr.Dataset: + np.random.seed(42) + lat = np.linspace(-90, 90, num=720) + lon = np.linspace(-180, 180, num=1440) + time = pd.date_range(start, end, freq='H') + level = np.array([1000, 500], dtype=np.int32) + reference_time = pd.Timestamp(start) + temperature = 15 + 8 * np.random.randn(720, 1440, len(time), len(level)) + precipitation = 10 * np.random.rand(720, 1440, len(time), len(level)) + return xr.Dataset( + data_vars=dict( + temperature=(['lat', 'lon', 'time', 'level'], temperature), + precipitation=(['lat', 'lon', 'time', 'level'], precipitation), + ), + coords=dict( + lat=lat, + lon=lon, + time=time, + level=level, + reference_time=reference_time, + ), + attrs=dict(description='Random weather.') + ) + + class DaskTestCase(unittest.TestCase): def setUp(self) -> None: @@ -18,6 +44,7 @@ def setUp(self) -> None: self.air_small = self.air.isel( time=slice(0, 12), lat=slice(0, 11), lon=slice(0, 10) ).chunk(self.chunks) + self.randwx = rand_wx('1995-01-13T00', '1995-01-13T01') class ExplodeTest(DaskTestCase): @@ -84,6 +111,13 @@ def test_chunk_perf(self): self.assertIsNotNone(df) self.assertEqual(len(df), np.prod(list(self.air.dims.values()))) + def test_column_metadata_preserved(self): + try: + _ = read_xarray(self.randwx, chunks=dict(time=24)).compute() + except ValueError as e: + if 'The columns in the computed data do not match the columns in the provided metadata' in str(e): + self.fail('Column metadata is incorrect.') + if __name__ == '__main__': unittest.main()