Skip to content

Commit

Permalink
Fix and generalize tests to parquet or protobuf input types
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreenbury committed Apr 15, 2024
1 parent 6f55a19 commit 1bc5b75
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 24 deletions.
52 changes: 30 additions & 22 deletions python/tests/test_builder.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,43 @@
import pytest
from test_utils import TEST_PATH, TEST_REGION
from uatk_spc.builder import Builder, unnest
from uatk_spc.reader import Reader

INPUT_TYPES = ["protobuf", "parquet"]

def test_unnest_data():
spc = Reader(TEST_PATH, TEST_REGION, backend="pandas")

@pytest.mark.parametrize("input_type", INPUT_TYPES)
def test_unnest_data(input_type):
spc = Reader(TEST_PATH, TEST_REGION, input_type, backend="pandas")
spc_unnested = unnest(spc.households, ["details"])
assert spc_unnested.columns.to_list() == [
"id",
"msoa",
"oa",
"members",
"hid",
"nssec8",
"accommodation_type",
"communal_type",
"num_rooms",
"central_heat",
"tenure",
"num_cars",
]
assert sorted(spc_unnested.columns.to_list()) == sorted(
[
"id",
"msoa11cd",
"oa11cd",
"members",
"hid",
"nssec8",
"accommodation_type",
"communal_type",
"num_rooms",
"central_heat",
"tenure",
"num_cars",
]
)


def test_add_households():
@pytest.mark.parametrize("input_type", INPUT_TYPES)
def test_add_households(input_type):
df_pandas = (
Builder(TEST_PATH, TEST_REGION, backend="pandas")
Builder(TEST_PATH, TEST_REGION, input_type, backend="pandas")
.add_households()
.unnest(["details"])
.build()
)
df_polars = (
Builder(TEST_PATH, TEST_REGION, backend="polars")
Builder(TEST_PATH, TEST_REGION, input_type, backend="polars")
.add_households()
.unnest(["details"])
.build()
Expand All @@ -42,7 +49,8 @@ def test_add_households():
assert set(df_pandas.columns.to_list()) == set(df_polars.columns)


def test_time_use_diaries_pandas():
@pytest.mark.parametrize("input_type", INPUT_TYPES)
def test_time_use_diaries_pandas(input_type):
features = {
"health": [
"bmi",
Expand All @@ -56,12 +64,12 @@ def test_time_use_diaries_pandas():
"employment": ["pwkstat", "salary_yearly"],
}
df_polars = (
Builder(TEST_PATH, TEST_REGION, backend="polars")
Builder(TEST_PATH, TEST_REGION, input_type, backend="polars")
.add_time_use_diaries(features)
.build()
)
df_pandas = (
Builder(TEST_PATH, TEST_REGION, backend="pandas")
Builder(TEST_PATH, TEST_REGION, input_type, backend="pandas")
.add_time_use_diaries(features)
.build()
)
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ def test_merge_people_and_time_use_diaries():
def test_merge_people_and_households():
spc = Reader(TEST_PATH, TEST_REGION)
merged = spc.merge_people_and_households()
assert merged.shape == (4991, 18)
assert merged.shape == (4991, 17)
2 changes: 1 addition & 1 deletion python/uatk_spc/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def add_time_use_diaries(
self.time_use_diaries,
pl.int_range(0, self.time_use_diaries.shape[0], eager=True)
.rename("index")
.cast(pl.UInt32)
.cast(people.dtypes[people.get_column_index(diary_type)])
.to_frame(),
],
how="horizontal",
Expand Down

0 comments on commit 1bc5b75

Please sign in to comment.