Skip to content

Commit

Permalink
Merge pull request #61 from ihmeuw-msca/feature/count_split
Browse files Browse the repository at this point in the history
Count splitting update
  • Loading branch information
saalUW authored Jul 9, 2024
2 parents a13ffcc + 4ddb1a8 commit 5ed743f
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 37 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ requires = [

[project]
name = "pydisagg"
version = "0.5.0"
version = "0.5.1"
description = ""
readme = "README.md"
license = { text = "BSD 2-Clause License" }
Expand All @@ -27,7 +27,7 @@ classifiers = [
]
dependencies = [
"matplotlib",
"numpy",
"numpy<2.0.0",
"pandas",
"scipy",
"pydantic",
Expand Down
10 changes: 5 additions & 5 deletions src/pydisagg/disaggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def split_datapoint(
rate_pattern: NDArray,
observed_total_se: Optional[float] = None,
model: Optional[DisaggModel] = LogOdds_model(),
output_type: Literal["total", "rate"] = "total",
output_type: Literal["count", "rate"] = "count",
normalize_pop_for_average_type_obs: bool = False,
pattern_covariance: Optional[NDArray] = None,
) -> Union[tuple, NDArray]:
Expand Down Expand Up @@ -71,7 +71,7 @@ def split_datapoint(
If observed_total_se is given, then returns a tuple
(point_estimate,standard_error)
"""
if output_type not in ["total", "rate"]:
if output_type not in ["count", "rate"]:
raise ValueError("output_type must be one of either 'total' or 'rate'")

if normalize_pop_for_average_type_obs is True:
Expand All @@ -81,7 +81,7 @@ def split_datapoint(
else:
processed_bucket_populations = bucket_populations.copy()

if output_type == "total":
if output_type == "count":
point_estimates = model.split_to_counts(
observed_total, rate_pattern, processed_bucket_populations
)
Expand Down Expand Up @@ -160,7 +160,7 @@ def split_dataframe(
rate_patterns: DataFrame,
use_se: Optional[bool] = False,
model: Optional[DisaggModel] = LogOdds_model(),
output_type: Literal["total", "rate"] = "total",
output_type: Literal["count", "rate"] = "count",
demographic_id_columns: Optional[list] = None,
normalize_pop_for_average_type_obs: bool = False,
) -> DataFrame:
Expand Down Expand Up @@ -220,7 +220,7 @@ def split_dataframe(
point estimate and standard error for the estimate for each group is given.
"""
if (normalize_pop_for_average_type_obs is True) and (
output_type == "total"
output_type == "count"
):
raise Warning(
"Normalizing populations may not be appropriate here, as we are working with a total"
Expand Down
59 changes: 42 additions & 17 deletions src/pydisagg/ihme/splitter/age_splitter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any
from typing import Any, Literal

import numpy as np
import pandas as pd
Expand All @@ -25,14 +25,19 @@ class AgeDataConfig(BaseModel):
age_upr: str
val: str
val_sd: str
# sample_size: str | None

@property
def columns(self) -> list[str]:
return list(
set(
self.index + [self.age_lwr, self.age_upr, self.val, self.val_sd]
)
)
base_columns = self.index + [
self.age_lwr,
self.age_upr,
self.val,
self.val_sd,
]
# if self.sample_size is not None:
# base_columns.append(self.sample_size)
return list(set(base_columns))


class AgePopulationConfig(BaseModel):
Expand Down Expand Up @@ -129,7 +134,7 @@ def model_post_init(self, __context: Any) -> None:
)

def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame:
name = "data"
name = "Parsing Data"
validate_columns(data, self.data.columns, name)

data = data[self.data.columns].copy()
Expand All @@ -147,7 +152,7 @@ def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame:
def parse_pattern(
self, data: DataFrame, pattern: DataFrame, positive_strict: bool
) -> DataFrame:
name = "pattern"
name = "Parsing Pattern"

if not all(
col in pattern.columns
Expand Down Expand Up @@ -181,10 +186,11 @@ def parse_pattern(
name,
)

pattern_copy = pattern.copy()
rename_map = self.pattern.apply_prefix()
pattern.rename(columns=rename_map, inplace=True)
pattern_copy.rename(columns=rename_map, inplace=True)

data_with_pattern = self._merge_with_pattern(data, pattern)
data_with_pattern = self._merge_with_pattern(data, pattern_copy)

validate_noindexdiff(data, data_with_pattern, self.data.index, name)
validate_pat_coverage(
Expand Down Expand Up @@ -218,18 +224,19 @@ def _merge_with_pattern(
def parse_population(
self, data: DataFrame, population: DataFrame
) -> DataFrame:
name = "population"
name = "Parsing Population"
validate_columns(population, self.population.columns, name)

population = population[self.population.columns].copy()

validate_index(population, self.population.index, name)
validate_nonan(population, name)

pop_copy = population.copy()
rename_map = self.population.apply_prefix()
population.rename(columns=rename_map, inplace=True)
pop_copy.rename(columns=rename_map, inplace=True)

data_with_population = self._merge_with_population(data, population)
data_with_population = self._merge_with_population(data, pop_copy)

validate_noindexdiff(
data,
Expand Down Expand Up @@ -279,15 +286,25 @@ def _align_pattern_and_population(self, data: DataFrame) -> DataFrame:
f"* ({self.data.age_upr} - {self.pattern.age_lwr})"
)

# Not used right now, but useful in checking how we handle population partitioning
# Can be used to split sample sizes using the pseudo-proportion
data[self.population.val + "_total"] = data.groupby(self.data.index)[
self.population.val + "_aligned"
].transform(lambda x: x.sum())
data[self.population.val + "_proportion"] = (
data[self.population.val + "_aligned"]
/ data[self.population.val + "_total"]
)

return data

def split(
self,
data: DataFrame,
pattern: DataFrame,
population: DataFrame,
model: str = "rate",
output_type: str = "rate",
model: Literal["rate", "logodds"] = "rate",
output_type: Literal["rate", "count"] = "rate",
propagate_zeros=False,
) -> DataFrame:
"""
Expand All @@ -308,7 +325,7 @@ def split(
output_type : str, optional
The type of output to be returned, by default "rate".
propagate_zeros : Bool, optional
Whether to propagate pre-split zeros as post split zeros. Default true
Whether to propagate pre-split zeros as post split zeros. Default false
Returns
-------
Expand Down Expand Up @@ -372,6 +389,11 @@ def split(
)

data_group = data.groupby(self.data.index)
if output_type == "count":
pop_normalize = False
elif output_type == "rate":
pop_normalize = True

for key, data_sub in data_group:
split_result, SE = split_datapoint(
observed_total=data_sub[self.data.val].iloc[0],
Expand All @@ -381,7 +403,7 @@ def split(
rate_pattern=data_sub[self.pattern.val + "_aligned"].to_numpy(),
model=model_instance,
output_type=output_type, # type: ignore, this is handeled by model_mapping
normalize_pop_for_average_type_obs=True,
normalize_pop_for_average_type_obs=pop_normalize,
observed_total_se=data_sub[self.data.val_sd].iloc[0],
pattern_covariance=np.diag(
data_sub[self.pattern.val_sd + "_aligned"].to_numpy() ** 2
Expand All @@ -396,4 +418,7 @@ def split(
self.pattern.remove_prefix()
self.population.remove_prefix()

# Something like this can be implemented for sample size split
# data["split_"+ self.data.sample_size] = data[self.data.sample_size] * data[self.population.val + "_proportion"]

return data
25 changes: 15 additions & 10 deletions src/pydisagg/ihme/splitter/sex_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
validate_index,
validate_nonan,
validate_positive,
validate_noindexdiff,
)


Expand Down Expand Up @@ -90,7 +91,7 @@ def _merge_with_pattern(
return data_with_pattern

def parse_data(self, data: DataFrame) -> DataFrame:
name = "data"
name = "When parsing, data"
validate_columns(data, self.data.columns, name)
data = data[self.data.columns].copy()
validate_index(data, self.data.index, name)
Expand All @@ -99,8 +100,7 @@ def parse_data(self, data: DataFrame) -> DataFrame:
return data

def parse_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame:
name = "pattern"

name = "When parsing, pattern"
if not all(
col in pattern.columns
for col in [self.pattern.val, self.pattern.val_sd]
Expand Down Expand Up @@ -132,12 +132,13 @@ def get_population_by_sex(self, population, sex_value):
def parse_population(
self, data: DataFrame, population: DataFrame
) -> DataFrame:
name = "population"
name = "When parsing, population"
validate_columns(population, self.population.columns, name)

male_population = self.get_population_by_sex(
population, self.population.sex_m
)

female_population = self.get_population_by_sex(
population, self.population.sex_f
)
Expand All @@ -152,11 +153,14 @@ def parse_population(
data_with_population = self._merge_with_population(
data, male_population, "m_pop"
)

data_with_population = self._merge_with_population(
data_with_population, female_population, "f_pop"
)

validate_columns(data_with_population, ["m_pop", "f_pop"], name)
validate_nonan(data_with_population, name)
validate_noindexdiff(data, data_with_population, self.data.index, name)
return data_with_population

def _merge_with_population(
Expand All @@ -177,15 +181,16 @@ def split(
pattern: DataFrame,
population: DataFrame,
model: str = "rate",
output_type: str = "rate",
) -> DataFrame:
data = self.parse_data(data)
data = self.parse_pattern(data, pattern)
data = self.parse_population(data, population)

if model != "rate":
raise ValueError(
"Only 'rate' model is currently supported for SexSplitter"
)
if output_type == "count":
pop_normalize = False
elif output_type == "rate":
pop_normalize = True

def sex_split_row(row):
split_result, SE = split_datapoint(
Expand All @@ -195,8 +200,8 @@ def sex_split_row(row):
# This is from sex_pattern
rate_pattern=np.array([1.0, row[self.pattern.val]]),
model=RateMultiplicativeModel(),
output_type="rate",
normalize_pop_for_average_type_obs=True,
output_type=output_type,
normalize_pop_for_average_type_obs=pop_normalize,
# This is from the data
observed_total_se=row[self.data.val_sd],
# This is from sex_pattern
Expand Down
20 changes: 18 additions & 2 deletions src/pydisagg/ihme/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
def validate_columns(df: DataFrame, columns: list[str], name: str) -> None:
missing = [col for col in columns if col not in df.columns]
if missing:
raise KeyError(f"{name} has missing columns: {missing}")
error_message = (
f"{name} has missing columns: {len(missing)} columns are missing.\n"
)
error_message += f"Missing columns: {', '.join(missing)}\n"
if len(missing) > 5:
error_message += "First 5 missing columns: \n"
error_message += ", \n".join(missing[:5])
error_message += "\n"
raise KeyError(error_message)


def validate_index(df: DataFrame, index: list[str], name: str) -> None:
Expand All @@ -26,7 +34,15 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None:
def validate_nonan(df: DataFrame, name: str) -> None:
nan_columns = df.columns[df.isna().any(axis=0)].to_list()
if nan_columns:
raise ValueError(f"{name} has NaN values in columns: {nan_columns}")
error_message = (
f"{name} has NaN values in {len(nan_columns)} columns. \n"
)
error_message += f"Columns with NaN values: {', '.join(nan_columns)}\n"
if len(nan_columns) > 5:
error_message += "First 5 columns with NaN values: \n"
error_message += ", \n".join(nan_columns[:5])
error_message += "\n"
raise ValueError(error_message)


def validate_positive(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_disaggregate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_count_model_consistency(model):
rate_pattern,
measurement_SE,
model,
output_type="total",
output_type="count",
)
assert_approx_equal(measured_total, np.sum(result))
assert_approx_equal(measurement_SE, np.sum(SE))
Expand Down

0 comments on commit 5ed743f

Please sign in to comment.