Skip to content
This repository has been archived by the owner on Aug 27, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into release
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed May 13, 2023
2 parents 4961665 + 4cccff4 commit ee9d7f3
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 207 deletions.
24 changes: 17 additions & 7 deletions ethicml/plot/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _multivariate_grid(
df: pd.DataFrame,
scatter_alpha: float = 0.5,
) -> None:
def colored_scatter(x: Any, y: Any, c: str | None = None) -> Callable[[Any], None]:
def colored_scatter(x: pd.Series, y: pd.Series, c: str | None = None) -> Callable[..., None]:
def scatter(*args: Any, **kwargs: Any) -> None:
args = (x, y)
if c is not None:
Expand All @@ -109,15 +109,25 @@ def scatter(*args: Any, **kwargs: Any) -> None:
for name, df_group in df.groupby([sens_col, outcome_col]):
legends.append(f"S={name[0]}, Y={name[1]}")
g.plot_joint(colored_scatter(df_group[col_x], df_group[col_y], color))
sns.distplot( # type: ignore[attr-defined]
df_group[col_x].to_numpy(), ax=g.ax_marg_x, color=color
sns.histplot( # type: ignore[attr-defined]
df_group[col_x].to_numpy(),
ax=g.ax_marg_x,
color=color,
kde=True,
stat="density",
kde_kws=dict(cut=3),
)
sns.distplot( # type: ignore[attr-defined]
df_group[col_y].to_numpy(), ax=g.ax_marg_y, vertical=True
sns.histplot( # type: ignore[attr-defined]
df_group[col_y].to_numpy(),
ax=g.ax_marg_y,
vertical=True,
kde=True,
stat="density",
kde_kws=dict(cut=3),
)
# Do also global Hist:
# sns.distplot(df[col_x].values, ax=g.ax_marg_x, color='grey')
# sns.distplot(df[col_y].values.ravel(), ax=g.ax_marg_y, color='grey', vertical=True)
# sns.histplot(df[col_x].values, ax=g.ax_marg_x, color='grey')
# sns.histplot(df[col_y].values.ravel(), ax=g.ax_marg_y, color='grey', vertical=True)
plt.legend(legends)


Expand Down
Loading

0 comments on commit ee9d7f3

Please sign in to comment.