Skip to content

Commit

Permalink
Merge pull request #12 from edahelsinki/development
Browse files Browse the repository at this point in the history
Plot improvements and interactive notebook
  • Loading branch information
Aggrathon authored Apr 21, 2023
2 parents ae72075 + c2544e5 commit a352a5f
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 12 deletions.
267 changes: 267 additions & 0 deletions examples/04_interactive_plots_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Interactive plots for Slisemap\n",
"\n",
"Since Slisemap is meant as a tool for exploration and investigation of datasets and machine learning models, some interactivity can be really benefitial.\n",
"In this notebook we explore some of the ways to make the plots more interactive.\n",
"\n",
"> NOTE: These plots will not show up in the statically rendered notebook on GitHub. You have to actually run the notebook to see the interactivity."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"from pathlib import Path\n",
"from urllib.request import urlretrieve\n",
"\n",
"sys.path.insert(0, \"..\")\n",
"\n",
"from slisemap import Slisemap"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Objectives\n",
"\n",
"These are the objectives of this notebook:\n",
"\n",
"- Demonstrate how to make interactive plots for Slisemap\n",
"- Discuss how interactive plots are benefitial"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cached results\n",
"\n",
"In this notebook we will reuse the results from a [previous notebook](01_regression_example_autompg.ipynb) (the dataset about cars and fuel efficiency):"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"SM_CACHE_PATH = Path(\"cache\") / \"01_regression_example_autompg.sm\"\n",
"\n",
"if not SM_CACHE_PATH.exists():\n",
" SM_CACHE_PATH.parent.mkdir(exist_ok=True, parents=True)\n",
" urlretrieve(\n",
" f\"https://raw.githubusercontent.com/edahelsinki/slisemap/data/examples/cache/{SM_CACHE_PATH.name}\",\n",
" SM_CACHE_PATH,\n",
" )\n",
"\n",
"sm = Slisemap.load(SM_CACHE_PATH, \"cpu\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## IPyWidgets\n",
"\n",
"An easy way to implement interactivity in any jupyter notebook is through the [IPyWidgets](https://ipywidgets.readthedocs.io) package."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from ipywidgets import interact"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The *IPyWidgets* package comes with an `interact` function/decorator that can be used to add visual controls to the normal Slisemap plots:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@interact(clusters=(2,10,1), jitter=(0, 0.1, 0.01), bars=[False, True])\n",
"def tmp(clusters=5, jitter=0, bars=True):\n",
" sm.plot(clusters=clusters, jitter=jitter, bars=bars)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The big benefit of interactive plots is that it makes it easy to flip back and forth between configuration, which makes the comparison faster.\n",
"For example, many Slisemap visualisations offer clustering to make interpretation easier, and interactive plots can be used to choose the number of clusters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@interact(jitter=(0, 0.1, 0.01), cols=(3, 6, 1))\n",
"def tmp(jitter=0, cols=4):\n",
" sm.plot_dist(scatter=True, jitter=jitter, col_wrap=cols)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Due to how Slisemap is defined, some points end up very close or even on top of each other.\n",
"One way to see the real density is to add some random noise, *jitter*, to the embedding.\n",
"With interactive plots it is easy to go between the true embedding and a (maybe) more informative embedding (with jitter)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@interact(clusters=(1,10,1), smoothing=(0.5, 1.25, 0.05))\n",
"def tmp(clusters=5, smoothing=0.75, cols=4):\n",
" sm.plot_dist(scatter=False, clusters=clusters, bw_adjust=smoothing)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Interactive plots can be used to control the level of details.\n",
"Clustering has already been mentioned, but another parameter is the smoothing in kde plots."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@interact(index=sm.metadata.get_rows(fallback=True))\n",
"def tmp(index=0):\n",
" sm.plot_position(index=index)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"However, the simplicity of *IPywidgets* makes it more useful for configuration than more complex interactions."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Slisemap interactive\n",
"\n",
"For interactive plots dedicated to Slisemap we can use the [slisemap_interactive](https://github.com/edahelsinki/slisemap_interactive) package.\n",
"In addition to controls similar to the ones above, *slisemap_intercative* also reacts to the mouse, e.g., hover over a point in the embedding to see more information about it in other plots.\n",
"Connected plots is benefitial for exploration since it is easier to select data items, and sync the selection between all plots."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from slisemap_interactive import plot\n",
"\n",
"plot(sm)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This works in both jupyter notebooks and from a normal Python REPL terminal. *slisemap_interactive* can also be used from a normal terminal without starting Python first:\n",
"\n",
"```{bash}\n",
"slisemap_interactive path/to/slisemap/object.sm\n",
"```\n",
"\n",
"Using *slisemap_interactive* like this gives you a fixed four-plot layout.\n",
"If you want more flexibility *slisemap_interactive* is also a plugin for [χiplot](https://github.com/edahelsinki/xiplot) (install both, run *χiplot*, and load a Slisemap file).\n",
"In *χiplot* you can individually add, remove, and configure the plots (including plots from both *χiplot* and *slisemap_interactive*)."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimisation\n",
"\n",
"With interactive controls, such as those from *IPyWidgets* we could also control the optimisation of Slisemap objects.\n",
"However, interactive updates should ideally not take more than a few seconds, which might be too short for optimisation.\n",
"One option is to pre-train the Slisemap object and then use the quicker `sm.lbfgs()` instead of a full `sm.optimise()`.\n",
"But a better alternative would be to pre-calculate all the Slisemap variants, in which case the calculations are just redrawing the plots with a different Slisemap object."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"This notebook demonstrates how to create interactive plots for Slisemap, ranging from \"configuration\" style interactivity to \"mouse-driven\" events.\n",
"The advantage of interactive plots is how easy it is to try different configurations, to find the best visualisations.\n",
"Deeper interactivity with multiple connected plots also speeds up exploration and interpretation of the results."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "slisemap",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Examples

This directory contains jupyter notebooks that demonstrate how to use SLISEMAP.
The recommended reading order is:

- [01_regression_example_autompg.ipynb](01_regression_example_autompg.ipynb) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/edahelsinki/slisemap/HEAD?labpath=examples%2F01_regression_example_autompg.ipynb)
- [02_classification_example_airquality.ipynb](02_classification_example_airquality.ipynb) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/edahelsinki/slisemap/HEAD?labpath=examples%2F02_classification_example_airquality.ipynb)
- [03_hyperparameter_tuning_example.ipynb](03_hyperparameter_tuning_example.ipynb) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/edahelsinki/slisemap/HEAD?labpath=examples%2F03_hyperparameter_tuning_example.ipynb)
- [04_interactive_plots_example.ipynb](04_interactive_plots_example.ipynb) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/edahelsinki/slisemap/HEAD?labpath=examples%2F04_interactive_plots_example.ipynb)

Additionally, the directory contains a brief tutorial on how to do optimization with Torch. This is not specific to SLISEMAP.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "slisemap"
version = "1.5.1"
version = "1.5.2"
authors = [
{ name = "Anton Björklund", email = "[email protected]" },
{ name = "Jarmo Mäkelä", email = "[email protected]" },
Expand Down
39 changes: 30 additions & 9 deletions slisemap/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def plot_embedding_facet(
names: Sequence[str],
legend_title: str = "Value",
jitter: Union[float, np.ndarray] = 0.0,
share_hue: bool = True,
**kwargs: Any,
) -> sns.FacetGrid:
"""Plot (multiple) embeddings.
Expand Down Expand Up @@ -278,15 +279,35 @@ def plot_embedding_facet(
for i, n in enumerate(names)
)
kwargs.setdefault("palette", "rocket")
kwargs.setdefault("kind", "scatter")
g = sns.relplot(
data=df,
x=dimensions[0],
y=dimensions[1],
hue=legend_title,
col="var",
**kwargs,
)
if share_hue:
kwargs.setdefault("kind", "scatter")
g = sns.relplot(
data=df,
x=dimensions[0],
y=dimensions[1],
hue=legend_title,
col="var",
**kwargs,
)
else:
fgkws = kwargs.pop("facet_kws", {})
fgkws.setdefault("height", 5)
for k in ("height", "aspect", "col_wrap"):
if k in kwargs:
fgkws[k] = kwargs.pop(k)
fgkws.setdefault("legend_out", False)
g = sns.FacetGrid(data=df, col="var", hue=legend_title, **fgkws)
for key, ax in g.axes_dict.items():
mask = df["var"] == key
df2 = {k: v[mask] for k, v in df.items()}
sns.scatterplot(
data=df2,
hue=legend_title,
x=dimensions[0],
y=dimensions[1],
ax=ax,
**kwargs,
)
g.set_titles("{col_name}")
return g

Expand Down
4 changes: 3 additions & 1 deletion slisemap/slisemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,7 @@ def plot(

kwargs.setdefault("figsize", (12, 6))
fig, (ax1, ax2) = plt.subplots(1, 2, **kwargs)
if clusters is None:
if not np.iterable(clusters) and not clusters:
_assert(not bars, "`bars!=False` requires `clusters`", Slisemap.plot)
if Z.shape[0] == self._Z.shape[0]:
yhat = self.predict(numpy=False)
Expand Down Expand Up @@ -1594,8 +1594,10 @@ def plot_dist(
labels,
jitter=jitter,
col_wrap=col_wrap,
share_hue=False,
**kwargs,
)
legend_inside = False
else:
if isinstance(clusters, int):
clusters, _ = self.get_model_clusters(clusters, B)
Expand Down
1 change: 1 addition & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def test_plot():
sm.plot(title="ASD", clusters=4, show=False)
sm.plot(title="ASD", clusters=4, bars=True, show=False)
sm.plot(title="ASD", clusters=4, bars=-1, show=False)
sm.plot(title="ASD", clusters=0, show=False)
sm.plot(clusters=cl, bars=False, show=False)
sm.plot(clusters=cl, bars=True, show=False)
cl2 = np.asarray([f"A{9-i}" for i in np.unique(cl)])[cl]
Expand Down

0 comments on commit a352a5f

Please sign in to comment.