Skip to content

Commit

Permalink
Make the pairwise function more convenient
Browse files Browse the repository at this point in the history
  • Loading branch information
dustalov committed Jul 9, 2024
1 parent 8c26b25 commit ab6eb81
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
7 changes: 4 additions & 3 deletions Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@
"metadata": {},
"outputs": [],
"source": [
"bt_pairwise = evalica.pairwise(bt_result.scores)\n",
"\n",
"df_bt_pairwise = pd.DataFrame(bt_pairwise, index=bt_result.scores.index, columns=bt_result.scores.index)\n",
"df_bt_pairwise = evalica.pairwise_frame(bt_result.scores.sort_values(ascending=False))\n",
"\n",
"df_bt_pairwise"
]
Expand All @@ -149,8 +147,11 @@
"source": [
"def visualize(df_pairwise: pd.DataFrame) -> Figure:\n",
" fig = px.imshow(df_pairwise, color_continuous_scale=\"RdBu\", text_auto=\".2f\")\n",
"\n",
" fig.update_layout(xaxis_title=\"Loser\", yaxis_title=\"Winner\", xaxis_side=\"top\")\n",
"\n",
" fig.update_traces(hovertemplate=\"Winner: %{y}<br>Loser: %{x}<br>Fraction of Wins: %{z}<extra></extra>\")\n",
"\n",
" return fig"
]
},
Expand Down
10 changes: 5 additions & 5 deletions python/evalica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,12 @@ def pagerank(
)


def _pairwise_ndarray(scores: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
def pairwise_scores(scores: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
return scores[:, np.newaxis] / (scores + scores[:, np.newaxis])


def pairwise(scores: pd.Series[T]) -> npt.NDArray[np.float64]: # type: ignore[type-var]
scores = scores.sort_values(ascending=False)
return _pairwise_ndarray(scores.to_numpy())
def pairwise_frame(scores: pd.Series[T]) -> pd.DataFrame: # type: ignore[type-var]
return pd.DataFrame(pairwise_scores(scores.to_numpy()), index=scores.index, columns=scores.index)


__all__ = [
Expand All @@ -340,5 +339,6 @@ def pairwise(scores: pd.Series[T]) -> npt.NDArray[np.float64]: # type: ignore[t
"matrices",
"newman",
"pagerank",
"pairwise",
"pairwise_scores",
"pairwise_frame",
]

0 comments on commit ab6eb81

Please sign in to comment.