diff --git a/python/uatk_spc/reader.py b/python/uatk_spc/reader.py index 543f430..ee8ee61 100644 --- a/python/uatk_spc/reader.py +++ b/python/uatk_spc/reader.py @@ -11,6 +11,17 @@ # - Add graph data structure reading for flows (e.g. into networkx) +# Type alias for a dataframe +DataFrame = pd.DataFrame | pl.DataFrame + + +def backend_error(backend: str) -> ValueError: + ValueError( + f"Backend: {backend} is not implemented. Use 'polars' or 'pandas' instead." + ) + + +# TODO: refactor with single class handling either proto or parquet class SPCReaderProto: """ A class for reading from protobuf into ready to use data structures. @@ -27,9 +38,9 @@ class SPCReaderProto: """ pop: synthpop_pb2.Population() - people: pl.DataFrame - households: pl.DataFrame - time_use_diaries: pl.DataFrame + people: DataFrame + households: DataFrame + time_use_diaries: DataFrame venues_per_activity: Dict[str, Any] info_per_msoa: Dict[str, Any] @@ -48,10 +59,8 @@ def __init__(self, path: str, region: str, backend="polars"): pop_as_dict["timeUseDiaries"] ) else: - raise ValueError( - f"Backend: {backend} is not implemented. Use 'polars' or 'pandas' " - f"instead." - ) + raise backend_error(backend) + self.venues_per_activity = pop_as_dict["venuesPerActivity"] self.info_per_msoa = pop_as_dict["infoPerMsoa"] @@ -76,13 +85,16 @@ class SPCReaderParquet: format. venues_per_activity (Dict[str, Any]): Venues per activity as a Python dict. info_per_msoa (Dict[str, Any]): Info per MSOA as a Python dict. + backend (str): DataFrame backend being used, must be either 'polars' or + 'pandas' """ - people: pl.DataFrame - households: pl.DataFrame - time_use_diaries: pl.DataFrame - venues_per_activity: pl.DataFrame + people: DataFrame + households: DataFrame + time_use_diaries: DataFrame + venues_per_activity: DataFrame info_per_msoa: dict + backend: str def __init__(self, path: str, region: str, backend="polars"): path_ = os.path.join(path, region) @@ -91,27 +103,23 @@ def __init__(self, path: str, region: str, backend="polars"): self.people = pl.read_parquet(path_ + "_people.pq") self.time_use_diaries = pl.read_parquet(path_ + "_time_use_diaries.pq") self.venues_per_activity = pl.read_parquet(path_ + "_venues.pq") + self.backend = "polars" elif backend == "pandas": self.households = pd.read_parquet(path_ + "_households.pq") self.people = pd.read_parquet(path_ + "_people.pq") self.time_use_diaries = pd.read_parquet(path_ + "_time_use_diaries.pq") self.venues_per_activity = pd.read_parquet(path_ + "_venues.pq") + self.backend = "pandas" else: - raise ValueError( - f"Backend: {backend} is not implemented. Use 'polars' or 'pandas' " - f"instead." - ) + raise backend_error(backend) + with open(path_ + "_info_per_msoa.json", "rb") as f: self.info_per_msoa = json.loads(f.read()) - def __summary( - self, df: pl.DataFrame - ) -> Dict[str, List[pl.datatypes.classes.DataTypeClass]]: + def __summary(self, df: DataFrame) -> Dict[str, List[Any]]: return dict(zip(df.columns, df.dtypes)) - def summary( - self, field: str - ) -> Dict[str, List[pl.datatypes.classes.DataTypeClass]] | None: + def summary(self, field: str) -> Dict[str, List[Any]] | None: """Provides a summary of the given SPC field. Args: @@ -146,19 +154,29 @@ def summary( ) ) - def merge(self, left: str, right: str, **kwargs) -> pl.DataFrame: + def merge(self, left: str, right: str, **kwargs) -> DataFrame: """Merges a left and right fields from SPC.""" # TODO: add implementation for any pair of fields pass - def merge_people_and_households(self) -> pl.DataFrame: - return self.people.unnest("identifiers").join( - self.households, left_on="household", right_on="id", how="left" - ) + def merge_people_and_households(self) -> DataFrame: + if self.backend == "polars": + return self.people.unnest("identifiers").join( + self.households, left_on="household", right_on="id", how="left" + ) + elif self.backend == "pandas": + # TODO: handle duplicate column names ("id") + return ( + self.people.drop(columns=["identifiers"]) + .join(pd.json_normalize(self.people["identifiers"])) + .merge(self.households, left_on="household", right_on="id", how="left") + ) + else: + raise backend_error(self.backend) def merge_people_and_time_use_diaries( self, people_features: Dict[str, List[str]], diary_type: str = "weekday_diaries" - ) -> pl.DataFrame: + ) -> DataFrame: people = ( self.people.unnest(people_features.keys()) .select(