Skip to content

Commit

Permalink
Add test for unnesting with overlapping columns
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreenbury committed Apr 16, 2024
1 parent 7e08052 commit 29f6a3d
Showing 1 changed file with 49 additions and 18 deletions.
67 changes: 49 additions & 18 deletions python/tests/test_builder.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
import itertools

import pytest
from test_utils import TEST_PATH, TEST_REGION
from uatk_spc.builder import Builder, unnest_pandas
from uatk_spc.builder import Builder, unnest_pandas, unnest_polars
from uatk_spc.reader import Reader

# Backends and on-disk input formats that the tests are parametrized over.
BACKENDS = ["pandas", "polars"]
INPUT_TYPES = ["protobuf", "parquet"]
# Every (input_type, backend) pair. Materialised as a list because a bare
# ``itertools.product`` object is a one-shot iterator: once one
# ``parametrize`` decorator has consumed it, any later use would see no cases.
PRODUCT = list(itertools.product(INPUT_TYPES, BACKENDS))

# Column names the unnest tests compare against after unnesting ``details``
# from the households table (compared order-insensitively via ``sorted``).
EXPECTED_COLUMNS = [
    "id",
    "msoa11cd",
    "oa11cd",
    "members",
    "hid",
    "nssec8",
    "accommodation_type",
    "communal_type",
    "num_rooms",
    "central_heat",
    "tenure",
    "num_cars",
]


@pytest.mark.parametrize("input_type", INPUT_TYPES)
def test_unnest_data(input_type):
def test_unnest_pandas_data(input_type):
spc = Reader(TEST_PATH, TEST_REGION, input_type, backend="pandas")
spc_unnested = unnest_pandas(spc.households, ["details"])
assert sorted(spc_unnested.columns.to_list()) == sorted(
[
"id",
"msoa11cd",
"oa11cd",
"members",
"hid",
"nssec8",
"accommodation_type",
"communal_type",
"num_rooms",
"central_heat",
"tenure",
"num_cars",
]
)
assert sorted(spc_unnested.columns.to_list()) == sorted(EXPECTED_COLUMNS)


@pytest.mark.parametrize("input_type", INPUT_TYPES)
def test_unnest_polars_data(input_type):
    """Unnest ``details`` from the polars households table and check columns.

    The resulting column names must match ``EXPECTED_COLUMNS`` regardless of
    ordering.
    """
    reader = Reader(TEST_PATH, TEST_REGION, input_type, backend="polars")
    unnested = unnest_polars(reader.households, ["details"])
    actual_columns = sorted(unnested.columns)
    expected_columns = sorted(EXPECTED_COLUMNS)
    assert actual_columns == expected_columns


@pytest.mark.parametrize("input_type", INPUT_TYPES)
Expand All @@ -49,6 +60,26 @@ def test_add_households(input_type):
assert set(df_pandas.columns.to_list()) == set(df_polars.columns)


@pytest.mark.parametrize(("input_type", "backend"), PRODUCT)
def test_column_overlap(input_type, backend):
    """Check that unnesting overlapping columns needs an ``rsuffix``.

    Unnesting both ``demographics`` and ``details`` gives an overlapping
    ``nssec8`` column: the build must fail without ``rsuffix`` and, with
    ``rsuffix="_household"``, must expose both ``nssec8`` and
    ``nssec8_household``.
    """
    # Exception: overlapping 'nssec8' without `rsuffix`.
    # NOTE(review): `Exception` is very broad; narrow it if both backends
    # raise a common, more specific type - confirm against the builder.
    with pytest.raises(Exception):
        # The result is deliberately not bound: only the failure matters here.
        (
            Builder(TEST_PATH, TEST_REGION, input_type, backend)
            .add_households()
            .unnest(["demographics", "details"])
            .build()
        )
    # Ok: overlapping 'nssec8' with `rsuffix` specified
    df = (
        Builder(TEST_PATH, TEST_REGION, input_type, backend)
        .add_households()
        .unnest(["demographics", "details"], rsuffix="_household")
        .build()
    )
    assert all(col in df.columns for col in ("nssec8", "nssec8_household"))


@pytest.mark.parametrize("input_type", INPUT_TYPES)
def test_time_use_diaries_pandas(input_type):
features = {
Expand Down

0 comments on commit 29f6a3d

Please sign in to comment.