diff --git a/python/uatk_spc/reader.py b/python/uatk_spc/reader.py index ced2f30..e89de7e 100644 --- a/python/uatk_spc/reader.py +++ b/python/uatk_spc/reader.py @@ -1,3 +1,4 @@ +import os from typing import Any, Dict from google.protobuf.json_format import MessageToDict import polars as pl @@ -35,9 +36,9 @@ class SPCReaderProto: venues_per_activity: Dict[str, Any] info_per_msoa: Dict[str, Any] - def __init__(self, path: str): + def __init__(self, path: str, region: str): """Init from a path and region.""" - self.pop = SPCReaderProto.read_pop(path) + self.pop = SPCReaderProto.read_pop(os.path.join(path, region + ".pb")) pop_as_dict = MessageToDict(self.pop, including_default_value_fields=True) self.households = pl.from_records(pop_as_dict["households"]) self.people = pl.from_records(pop_as_dict["people"]) @@ -73,8 +74,8 @@ class SPCReaderParquet: venues_per_activity: pl.DataFrame info_per_msoa: dict - def __init__(self, path: str): - path_ = path.split(".pb")[0] + def __init__(self, path: str, region: str): + path_ = os.path.join(path, region) self.households = pl.read_parquet(path_ + "_households.pq") self.people = pl.read_parquet(path_ + "_people.pq") self.time_use_diaries = pl.read_parquet(path_ + "_time_use_diaries.pq")