Skip to content

Commit

Permalink
apply Black 2024 style in fbcode (10/16)
Browse files Browse the repository at this point in the history
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447733

fbshipit-source-id: 11ac742489579bb1dfec025514aa956159cf4959
  • Loading branch information
amyreese authored and facebook-github-bot committed Mar 3, 2024
1 parent 72e54bc commit 592483a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 25 deletions.
24 changes: 15 additions & 9 deletions balance/stats_and_plots/weighted_comparisons_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,9 +773,11 @@ def plotly_plot_qq(
weighted_quantile(
dict_of_dfs["target"][variable],
np.arange(0, 1, 0.001),
dict_of_dfs["target"]["weight"]
if "weight" in dict_of_dfs["target"].columns
else None,
(
dict_of_dfs["target"]["weight"]
if "weight" in dict_of_dfs["target"].columns
else None
),
)["col1"]
)

Expand All @@ -798,9 +800,11 @@ def plotly_plot_qq(
weighted_quantile(
dict_of_dfs[name][variable],
np.arange(0, 1, 0.001),
dict_of_dfs[name]["weight"]
if "weight" in dict_of_dfs[name].columns
else None,
(
dict_of_dfs[name]["weight"]
if "weight" in dict_of_dfs[name].columns
else None
),
)["col1"]
),
marker={
Expand Down Expand Up @@ -1068,9 +1072,11 @@ def plotly_plot_bar(
df_plot_data = relative_frequency_table(
dict_of_dfs[name],
variable,
dict_of_dfs[name]["weight"]
if "weight" in dict_of_dfs[name].columns
else None,
(
dict_of_dfs[name]["weight"]
if "weight" in dict_of_dfs[name].columns
else None
),
)

variable_specific_dict_of_plots[name] = go.Bar(
Expand Down
8 changes: 5 additions & 3 deletions balance/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,9 +1181,11 @@ def _return_type_creation_function(x: Any) -> Union[Callable, Any]:

# Reapply the index for pd.Series
r = [
pd.Series(data, index=pd.Series(orig_data)[nonmissing_mask].index)
if isinstance(orig_data, pd.Series)
else data
(
pd.Series(data, index=pd.Series(orig_data)[nonmissing_mask].index)
if isinstance(orig_data, pd.Series)
else data
)
for data, orig_data in zip(r, args)
]

Expand Down
17 changes: 7 additions & 10 deletions balance/weighting_methods/cbps.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,7 @@ def _reverse_svd_and_centralization(beta, U, s, Vh, X_matrix_mean, X_matrix_std)
beta_new = np.insert(
beta_new,
0,
beta[
0,
]
- np.matmul(X_matrix_mean, beta_new),
beta[0,] - np.matmul(X_matrix_mean, beta_new),
)
return beta_new

Expand Down Expand Up @@ -715,14 +712,14 @@ def cbps( # noqa
# The following are the results of the optimizations
"rescale_initial_result": rescale_initial_result,
"balance_optimize_result": balance_optimize_result,
"gmm_optimize_result_glm_init": gmm_optimize_result_glm_init
if cbps_method == "over"
else None,
"gmm_optimize_result_glm_init": (
gmm_optimize_result_glm_init if cbps_method == "over" else None
),
# pyre-fixme[61]: `gmm_optimize_result_bal_init` is undefined, or not
# always defined.
"gmm_optimize_result_bal_init": gmm_optimize_result_bal_init
if cbps_method == "over"
else None,
"gmm_optimize_result_bal_init": (
gmm_optimize_result_bal_init if cbps_method == "over" else None
),
},
}
logger.info("Done cbps function")
Expand Down
4 changes: 1 addition & 3 deletions balance/weighting_methods/ipw.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,9 +553,7 @@ def ipw(
raise NotImplementedError()
logger.debug(f"fit['lambda_1se']: {fit['lambda_1se']}")

X_matrix_sample = X_matrix[
:sample_n,
].toarray()
X_matrix_sample = X_matrix[:sample_n,].toarray()

logger.info(f"max_de: {max_de}")
if max_de is not None:
Expand Down

0 comments on commit 592483a

Please sign in to comment.