diff --git a/balance/stats_and_plots/weighted_comparisons_plots.py b/balance/stats_and_plots/weighted_comparisons_plots.py index adb3991..b48c6a5 100644 --- a/balance/stats_and_plots/weighted_comparisons_plots.py +++ b/balance/stats_and_plots/weighted_comparisons_plots.py @@ -773,9 +773,11 @@ def plotly_plot_qq( weighted_quantile( dict_of_dfs["target"][variable], np.arange(0, 1, 0.001), - dict_of_dfs["target"]["weight"] - if "weight" in dict_of_dfs["target"].columns - else None, + ( + dict_of_dfs["target"]["weight"] + if "weight" in dict_of_dfs["target"].columns + else None + ), )["col1"] ) @@ -798,9 +800,11 @@ def plotly_plot_qq( weighted_quantile( dict_of_dfs[name][variable], np.arange(0, 1, 0.001), - dict_of_dfs[name]["weight"] - if "weight" in dict_of_dfs[name].columns - else None, + ( + dict_of_dfs[name]["weight"] + if "weight" in dict_of_dfs[name].columns + else None + ), )["col1"] ), marker={ @@ -1068,9 +1072,11 @@ def plotly_plot_bar( df_plot_data = relative_frequency_table( dict_of_dfs[name], variable, - dict_of_dfs[name]["weight"] - if "weight" in dict_of_dfs[name].columns - else None, + ( + dict_of_dfs[name]["weight"] + if "weight" in dict_of_dfs[name].columns + else None + ), ) variable_specific_dict_of_plots[name] = go.Bar( diff --git a/balance/util.py b/balance/util.py index 816c0d2..ac0653e 100644 --- a/balance/util.py +++ b/balance/util.py @@ -1181,9 +1181,11 @@ def _return_type_creation_function(x: Any) -> Union[Callable, Any]: # Reapply the index for pd.Series r = [ - pd.Series(data, index=pd.Series(orig_data)[nonmissing_mask].index) - if isinstance(orig_data, pd.Series) - else data + ( + pd.Series(data, index=pd.Series(orig_data)[nonmissing_mask].index) + if isinstance(orig_data, pd.Series) + else data + ) for data, orig_data in zip(r, args) ] diff --git a/balance/weighting_methods/cbps.py b/balance/weighting_methods/cbps.py index cd5a226..638c32d 100644 --- a/balance/weighting_methods/cbps.py +++ b/balance/weighting_methods/cbps.py @@ -305,10 +305,7 @@ def _reverse_svd_and_centralization(beta, U, s, Vh, X_matrix_mean, X_matrix_std) beta_new = np.insert( beta_new, 0, - beta[ - 0, - ] - - np.matmul(X_matrix_mean, beta_new), + beta[0,] - np.matmul(X_matrix_mean, beta_new), ) return beta_new @@ -715,14 +712,14 @@ def cbps( # noqa # The following are the results of the optimizations "rescale_initial_result": rescale_initial_result, "balance_optimize_result": balance_optimize_result, - "gmm_optimize_result_glm_init": gmm_optimize_result_glm_init - if cbps_method == "over" - else None, + "gmm_optimize_result_glm_init": ( + gmm_optimize_result_glm_init if cbps_method == "over" else None + ), # pyre-fixme[61]: `gmm_optimize_result_bal_init` is undefined, or not # always defined. - "gmm_optimize_result_bal_init": gmm_optimize_result_bal_init - if cbps_method == "over" - else None, + "gmm_optimize_result_bal_init": ( + gmm_optimize_result_bal_init if cbps_method == "over" else None + ), }, } logger.info("Done cbps function") diff --git a/balance/weighting_methods/ipw.py b/balance/weighting_methods/ipw.py index f24d9ba..f9acf10 100644 --- a/balance/weighting_methods/ipw.py +++ b/balance/weighting_methods/ipw.py @@ -553,9 +553,7 @@ def ipw( raise NotImplementedError() logger.debug(f"fit['lambda_1se']: {fit['lambda_1se']}") - X_matrix_sample = X_matrix[ - :sample_n, - ].toarray() + X_matrix_sample = X_matrix[:sample_n,].toarray() logger.info(f"max_de: {max_de}") if max_de is not None: