diff --git a/pyproject.toml b/pyproject.toml index 88d1f22..082a4e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ [project] name = "pydisagg" -version = "0.5.0" +version = "0.5.1" description = "" readme = "README.md" license = { text = "BSD 2-Clause License" } @@ -27,7 +27,7 @@ classifiers = [ ] dependencies = [ "matplotlib", - "numpy", + "numpy<2.0.0", "pandas", "scipy", "pydantic", diff --git a/src/pydisagg/disaggregate.py b/src/pydisagg/disaggregate.py index 4b1052f..5018db1 100644 --- a/src/pydisagg/disaggregate.py +++ b/src/pydisagg/disaggregate.py @@ -17,7 +17,7 @@ def split_datapoint( rate_pattern: NDArray, observed_total_se: Optional[float] = None, model: Optional[DisaggModel] = LogOdds_model(), - output_type: Literal["total", "rate"] = "total", + output_type: Literal["count", "rate"] = "count", normalize_pop_for_average_type_obs: bool = False, pattern_covariance: Optional[NDArray] = None, ) -> Union[tuple, NDArray]: @@ -71,7 +71,7 @@ def split_datapoint( If observed_total_se is given, then returns a tuple (point_estimate,standard_error) """ - if output_type not in ["total", "rate"]: + if output_type not in ["count", "rate"]: raise ValueError("output_type must be one of either 'total' or 'rate'") if normalize_pop_for_average_type_obs is True: @@ -81,7 +81,7 @@ def split_datapoint( else: processed_bucket_populations = bucket_populations.copy() - if output_type == "total": + if output_type == "count": point_estimates = model.split_to_counts( observed_total, rate_pattern, processed_bucket_populations ) @@ -160,7 +160,7 @@ def split_dataframe( rate_patterns: DataFrame, use_se: Optional[bool] = False, model: Optional[DisaggModel] = LogOdds_model(), - output_type: Literal["total", "rate"] = "total", + output_type: Literal["count", "rate"] = "count", demographic_id_columns: Optional[list] = None, normalize_pop_for_average_type_obs: bool = False, ) -> DataFrame: @@ -220,7 +220,7 @@ def split_dataframe( point estimate and standard error for the estimate for each group is given. """ if (normalize_pop_for_average_type_obs is True) and ( - output_type == "total" + output_type == "count" ): raise Warning( "Normalizing populations may not be appropriate here, as we are working with a total" diff --git a/src/pydisagg/ihme/splitter/age_splitter.py b/src/pydisagg/ihme/splitter/age_splitter.py index 1b53701..219ce1a 100644 --- a/src/pydisagg/ihme/splitter/age_splitter.py +++ b/src/pydisagg/ihme/splitter/age_splitter.py @@ -1,5 +1,5 @@ import warnings -from typing import Any +from typing import Any, Literal import numpy as np import pandas as pd @@ -25,14 +25,19 @@ class AgeDataConfig(BaseModel): age_upr: str val: str val_sd: str + # sample_size: str | None @property def columns(self) -> list[str]: - return list( - set( - self.index + [self.age_lwr, self.age_upr, self.val, self.val_sd] - ) - ) + base_columns = self.index + [ + self.age_lwr, + self.age_upr, + self.val, + self.val_sd, + ] + # if self.sample_size is not None: + # base_columns.append(self.sample_size) + return list(set(base_columns)) class AgePopulationConfig(BaseModel): @@ -129,7 +134,7 @@ def model_post_init(self, __context: Any) -> None: ) def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame: - name = "data" + name = "Parsing Data" validate_columns(data, self.data.columns, name) data = data[self.data.columns].copy() @@ -147,7 +152,7 @@ def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame: def parse_pattern( self, data: DataFrame, pattern: DataFrame, positive_strict: bool ) -> DataFrame: - name = "pattern" + name = "Parsing Pattern" if not all( col in pattern.columns @@ -181,10 +186,11 @@ def parse_pattern( name, ) + pattern_copy = pattern.copy() rename_map = self.pattern.apply_prefix() - pattern.rename(columns=rename_map, inplace=True) + pattern_copy.rename(columns=rename_map, inplace=True) - data_with_pattern = self._merge_with_pattern(data, pattern) + data_with_pattern = self._merge_with_pattern(data, pattern_copy) validate_noindexdiff(data, data_with_pattern, self.data.index, name) validate_pat_coverage( @@ -218,7 +224,7 @@ def _merge_with_pattern( def parse_population( self, data: DataFrame, population: DataFrame ) -> DataFrame: - name = "population" + name = "Parsing Population" validate_columns(population, self.population.columns, name) population = population[self.population.columns].copy() @@ -226,10 +232,11 @@ def parse_population( validate_index(population, self.population.index, name) validate_nonan(population, name) + pop_copy = population.copy() rename_map = self.population.apply_prefix() - population.rename(columns=rename_map, inplace=True) + pop_copy.rename(columns=rename_map, inplace=True) - data_with_population = self._merge_with_population(data, population) + data_with_population = self._merge_with_population(data, pop_copy) validate_noindexdiff( data, @@ -279,6 +286,16 @@ def _align_pattern_and_population(self, data: DataFrame) -> DataFrame: f"* ({self.data.age_upr} - {self.pattern.age_lwr})" ) + # Not used right now, but useful in checking how we handle population partitioning + # Can be used to split sample sizes using the pseudo-proportion + data[self.population.val + "_total"] = data.groupby(self.data.index)[ + self.population.val + "_aligned" + ].transform(lambda x: x.sum()) + data[self.population.val + "_proportion"] = ( + data[self.population.val + "_aligned"] + / data[self.population.val + "_total"] + ) + return data def split( @@ -286,8 +303,8 @@ def split( data: DataFrame, pattern: DataFrame, population: DataFrame, - model: str = "rate", - output_type: str = "rate", + model: Literal["rate", "logodds"] = "rate", + output_type: Literal["rate", "count"] = "rate", propagate_zeros=False, ) -> DataFrame: """ @@ -308,7 +325,7 @@ def split( output_type : str, optional The type of output to be returned, by default "rate". propagate_zeros : Bool, optional - Whether to propagate pre-split zeros as post split zeros. Default true + Whether to propagate pre-split zeros as post split zeros. Default false Returns ------- @@ -372,6 +389,11 @@ def split( ) data_group = data.groupby(self.data.index) + if output_type == "count": + pop_normalize = False + elif output_type == "rate": + pop_normalize = True + for key, data_sub in data_group: split_result, SE = split_datapoint( observed_total=data_sub[self.data.val].iloc[0], @@ -381,7 +403,7 @@ def split( rate_pattern=data_sub[self.pattern.val + "_aligned"].to_numpy(), model=model_instance, output_type=output_type, # type: ignore, this is handeled by model_mapping - normalize_pop_for_average_type_obs=True, + normalize_pop_for_average_type_obs=pop_normalize, observed_total_se=data_sub[self.data.val_sd].iloc[0], pattern_covariance=np.diag( data_sub[self.pattern.val_sd + "_aligned"].to_numpy() ** 2 @@ -396,4 +418,7 @@ def split( self.pattern.remove_prefix() self.population.remove_prefix() + # Something like this can be implemented for sample size split + # data["split_"+ self.data.sample_size] = data[self.data.sample_size] * data[self.population.val + "_proportion"] + return data diff --git a/src/pydisagg/ihme/splitter/sex_splitter.py b/src/pydisagg/ihme/splitter/sex_splitter.py index 66e50c1..7dd2935 100644 --- a/src/pydisagg/ihme/splitter/sex_splitter.py +++ b/src/pydisagg/ihme/splitter/sex_splitter.py @@ -10,6 +10,7 @@ validate_index, validate_nonan, validate_positive, + validate_noindexdiff, ) @@ -90,7 +91,7 @@ def _merge_with_pattern( return data_with_pattern def parse_data(self, data: DataFrame) -> DataFrame: - name = "data" + name = "When parsing, data" validate_columns(data, self.data.columns, name) data = data[self.data.columns].copy() validate_index(data, self.data.index, name) @@ -99,8 +100,7 @@ def parse_data(self, data: DataFrame) -> DataFrame: return data def parse_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame: - name = "pattern" - + name = "When parsing, pattern" if not all( col in pattern.columns for col in [self.pattern.val, self.pattern.val_sd] @@ -132,12 +132,13 @@ def get_population_by_sex(self, population, sex_value): def parse_population( self, data: DataFrame, population: DataFrame ) -> DataFrame: - name = "population" + name = "When parsing, population" validate_columns(population, self.population.columns, name) male_population = self.get_population_by_sex( population, self.population.sex_m ) + female_population = self.get_population_by_sex( population, self.population.sex_f ) @@ -152,11 +153,14 @@ def parse_population( data_with_population = self._merge_with_population( data, male_population, "m_pop" ) + data_with_population = self._merge_with_population( data_with_population, female_population, "f_pop" ) + validate_columns(data_with_population, ["m_pop", "f_pop"], name) validate_nonan(data_with_population, name) + validate_noindexdiff(data, data_with_population, self.data.index, name) return data_with_population def _merge_with_population( @@ -177,15 +181,16 @@ def split( pattern: DataFrame, population: DataFrame, model: str = "rate", + output_type: str = "rate", ) -> DataFrame: data = self.parse_data(data) data = self.parse_pattern(data, pattern) data = self.parse_population(data, population) - if model != "rate": - raise ValueError( - "Only 'rate' model is currently supported for SexSplitter" - ) + if output_type == "count": + pop_normalize = False + elif output_type == "rate": + pop_normalize = True def sex_split_row(row): split_result, SE = split_datapoint( @@ -195,8 +200,8 @@ def sex_split_row(row): # This is from sex_pattern rate_pattern=np.array([1.0, row[self.pattern.val]]), model=RateMultiplicativeModel(), - output_type="rate", - normalize_pop_for_average_type_obs=True, + output_type=output_type, + normalize_pop_for_average_type_obs=pop_normalize, # This is from the data observed_total_se=row[self.data.val_sd], # This is from sex_pattern diff --git a/src/pydisagg/ihme/validator.py b/src/pydisagg/ihme/validator.py index b0b7dbb..9305456 100644 --- a/src/pydisagg/ihme/validator.py +++ b/src/pydisagg/ihme/validator.py @@ -6,7 +6,15 @@ def validate_columns(df: DataFrame, columns: list[str], name: str) -> None: missing = [col for col in columns if col not in df.columns] if missing: - raise KeyError(f"{name} has missing columns: {missing}") + error_message = ( + f"{name} has missing columns: {len(missing)} columns are missing.\n" + ) + error_message += f"Missing columns: {', '.join(missing)}\n" + if len(missing) > 5: + error_message += "First 5 missing columns: \n" + error_message += ", \n".join(missing[:5]) + error_message += "\n" + raise KeyError(error_message) def validate_index(df: DataFrame, index: list[str], name: str) -> None: @@ -26,7 +34,15 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None: def validate_nonan(df: DataFrame, name: str) -> None: nan_columns = df.columns[df.isna().any(axis=0)].to_list() if nan_columns: - raise ValueError(f"{name} has NaN values in columns: {nan_columns}") + error_message = ( + f"{name} has NaN values in {len(nan_columns)} columns. \n" + ) + error_message += f"Columns with NaN values: {', '.join(nan_columns)}\n" + if len(nan_columns) > 5: + error_message += "First 5 columns with NaN values: \n" + error_message += ", \n".join(nan_columns[:5]) + error_message += "\n" + raise ValueError(error_message) def validate_positive( diff --git a/tests/test_disaggregate_api.py b/tests/test_disaggregate_api.py index fac2348..a6c967e 100644 --- a/tests/test_disaggregate_api.py +++ b/tests/test_disaggregate_api.py @@ -27,7 +27,7 @@ def test_count_model_consistency(model): rate_pattern, measurement_SE, model, - output_type="total", + output_type="count", ) assert_approx_equal(measured_total, np.sum(result)) assert_approx_equal(measurement_SE, np.sum(SE))