From 2a017929f72dd3a2eb6c63ee6a84a6db39b1e4b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Mon, 19 Feb 2024 12:27:32 +0200 Subject: [PATCH 01/11] ruff linting --- pyproject.toml | 53 ++++++- slisemap_interactive/__init__.py | 40 +++--- slisemap_interactive/__main__.py | 3 + slisemap_interactive/app.py | 42 +++--- slisemap_interactive/layout.py | 17 +-- slisemap_interactive/load.py | 63 ++++++--- slisemap_interactive/plots.py | 236 ++++++++++++++++++++++--------- slisemap_interactive/xiplot.py | 196 ++++++++++++++++--------- tests/__init__py | 0 tests/test_load.py | 2 +- tests/test_plots.py | 20 ++- 11 files changed, 463 insertions(+), 209 deletions(-) delete mode 100644 tests/__init__py diff --git a/pyproject.toml b/pyproject.toml index 1e8628b..8966d29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ [project.optional-dependencies] xiplot = ["xiplot"] -dev = ["pytest", "black[jupyter]", "pylint", "jupyter", "IPython"] +dev = ["pytest", "pytest-cov", "black[jupyter]", "ruff", "jupyter", "IPython"] [project.urls] github = "https://github.com/edahelsinki/slisemap_interactive" @@ -50,3 +50,54 @@ slisemap_interactive = "slisemap_interactive.app:cli" [tool.setuptools] packages = ["slisemap_interactive"] + +[tool.coverage.run] +branch = true + +[tool.coverage.report] +exclude_also = [ + "_deprecated", + "print", + "plt.show", + "if verbose", + "ImportError", + "_warn", +] + +[tool.ruff.lint] +select = [ + "I", + "E", + "F", + "B", + "C4", + "W", + "D", + "UP", + "ANN", + "SIM", + "RUF", + "S", + "N", +] +ignore = [ + "E501", + "B006", + "D105", + "D203", + "D204", + "D406", + "D213", + "D407", + "D413", + "ANN101", + "ANN102", + "ANN401", + "S101", + "N802", + "N803", + "N806", +] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["D", "ANN"] diff --git a/slisemap_interactive/__init__.py b/slisemap_interactive/__init__.py index 6fc42c2..49020d5 100644 --- a/slisemap_interactive/__init__.py +++ b/slisemap_interactive/__init__.py @@ -1,25 +1,23 @@ -""" - Slisemap - Interactive: A Dash app for interactively visualising Slisemap objects - ================================================================================= +"""Slisemap - Interactive: A Dash app for interactively visualising Slisemap objects. +================================================================================= - Use the `plot` function for non-blocking interactive plots in notebooks and interactive interpreters. - In (non-Python) terminals you can call `slisemap_interactive` to start a standalone application. - Finally, this package also integrates into χiplot as a plugin. +Use the `plot` function for non-blocking interactive plots in notebooks and interactive interpreters. +In (non-Python) terminals you can call `slisemap_interactive` to start a standalone application. +Finally, this package also integrates into χiplot as a plugin. - Relevant links: - --------------- +Relevant links: +--------------- - - [GitHub repository](https://github.com/edahelsinki/slisemap_interactive) - - [Slisemap](https://github.com/edahelsinki/slisemap) - - [χiplot](https://github.com/edahelsinki/xiplot) - - [Dash](https://dash.plotly.com/) -""" +- [GitHub repository](https://github.com/edahelsinki/slisemap_interactive) +- [Slisemap](https://github.com/edahelsinki/slisemap) +- [χiplot](https://github.com/edahelsinki/xiplot) +- [Dash](https://dash.plotly.com/) +""" # noqa: D205 -from slisemap_interactive.app import plot, shutdown, BackgroundApp, ForegroundApp -from slisemap_interactive.load import load, slisemap_to_dataframe - - -def __version__(): - from importlib.metadata import version - - return version("slisemap_interactive") +from slisemap_interactive.app import ( # noqa: F401 + BackgroundApp, + ForegroundApp, + plot, + shutdown, +) +from slisemap_interactive.load import load, slisemap_to_dataframe # noqa: F401 diff --git a/slisemap_interactive/__main__.py b/slisemap_interactive/__main__.py index c6947b8..87632ba 100644 --- a/slisemap_interactive/__main__.py +++ b/slisemap_interactive/__main__.py @@ -1,4 +1,7 @@ +"""Run the slisemap_interactive app.""" + import sys + from slisemap_interactive.app import cli # This script is only called when there is a local copy of the repository, diff --git a/slisemap_interactive/app.py b/slisemap_interactive/app.py index f701e95..305abfd 100644 --- a/slisemap_interactive/app.py +++ b/slisemap_interactive/app.py @@ -1,6 +1,5 @@ -""" - Simple standalone Dash app. -""" +"""Simple standalone Dash app.""" + import argparse import os from os import PathLike @@ -17,7 +16,7 @@ # But jupyter_dash is not compatible with pyodide. JupyterDash = Dash -from slisemap_interactive.layout import register_callbacks, page_with_all_plots +from slisemap_interactive.layout import page_with_all_plots, register_callbacks from slisemap_interactive.load import ( DEFAULT_MAX_L, DEFAULT_MAX_N, @@ -28,9 +27,9 @@ from slisemap_interactive.plots import DataCache -def cli(): - """ - Plot a slisemap object interactively. +def cli() -> None: + """Plot a slisemap object interactively. + This function acts like a command line program. Arguments are parsed from `sys.argv` using `argparse.ArgumentParser()`. """ @@ -70,7 +69,7 @@ def cli(): args = parser.parse_args() path = args.PATH if os.path.isdir(path): - for path in [f for f in os.listdir(path) if f.endswith(".sm")]: + for path in [f for f in os.listdir(path) if f.endswith(".sm")]: # noqa: B020 print("Using:", path) break if args.export: @@ -97,8 +96,9 @@ def plot( mode: Literal[None, "inline", "external", "jupyterlab"] = None, appargs: Dict[str, Any] = {}, **runargs: Any, -): +) -> None: """Plot a Slisemap object interactively. + This function is designed to be called from a jupyter notebook or an interactive Python shell. This function automatically starts a server in the background. @@ -115,7 +115,7 @@ def plot( app.set_data(slisemap, max_n).display(width, height, mode) -def shutdown(): +def shutdown() -> None: """Shutdown the current background server for interactive Slisemap plots. This is a shortcut for `BackgroundApp.get_app().shutdown()`. @@ -131,7 +131,8 @@ def shutdown(): class ForegroundApp(Dash): """Create a blocking Dash app for interactive visualisations of a Slisemap object.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Create a blocking Dash app.""" super().__init__(*args, title="Interactive Slisemap", **kwargs) self.data_cache = DataCache() register_callbacks(self, self.data_cache) @@ -142,6 +143,7 @@ def set_data( max_n: int = DEFAULT_MAX_N, ) -> "ForegroundApp": """Set which data the app should show new connections. + Old data is cached so that old connections continue working. For existing connections, refresh the page to get the latest data. @@ -164,7 +166,7 @@ class BackgroundApp(JupyterDash): # Store current app for reuse as a singleton __app = None - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Create the `BackgroundApp` server, see `dash.Dash()` for arguments.""" super().__init__(*args, title="Interactive Slisemap", **kwargs) self._display_url = None @@ -179,6 +181,7 @@ def set_data( max_n: int = DEFAULT_MAX_N, ) -> "BackgroundApp": """Set which data the app should show new connections. + Old data is cached so that old connections continue working. For existing connections, refresh the page to get the latest data. @@ -216,23 +219,24 @@ def get_app( app.run(**runargs) return app - def run(self, *args, **kwargs): + def run(self, *args: Any, **kwargs: Any) -> None: """Start the server, see `dash.JupyterDash().run_server()` for arguments.""" if BackgroundApp.__app is not None: warn( "A `BackgroundApp` already exists. Use `BackgroundApp.get_app(...)` to reuse it.", Warning, + stacklevel=1, ) super().run_server(*args, **kwargs) BackgroundApp.__app = self run_server = run - def shutdown(self): - """Shutdown the server""" + def shutdown(self) -> None: + """Shutdown the server.""" old_server = [ (host, port) - for host, port in self._server_threads.keys() + for host, port in self._server_threads if port == self._display_port and host in self._display_url ] for key in old_server: @@ -245,13 +249,13 @@ def shutdown(self): if BackgroundApp.__app == self: BackgroundApp.__app = None - def _display_in_colab(self, url, port, mode, width, height): + def _display_in_colab(self, url: str, port: int, *_: Any, **__: Any) -> None: # Catch parameters to the display function for reuse later (see `BackgroundApp().display()`) self._display_url = url self._display_port = port self._display_call = super()._display_in_colab - def _display_in_jupyter(self, url, port, mode, width, height): + def _display_in_jupyter(self, url: str, port: int, *_: Any, **__: Any) -> None: # Catch parameters to the display function for reuse later (see `BackgroundApp().display()`) self._display_url = url self._display_port = port @@ -262,7 +266,7 @@ def display( width: Union[str, int] = "100%", height: Union[str, int] = 1000, mode: Literal[None, "inline", "external", "jupyterlab"] = None, - ): + ) -> None: """Display the plots. Args: diff --git a/slisemap_interactive/layout.py b/slisemap_interactive/layout.py index ad53c47..a87658d 100644 --- a/slisemap_interactive/layout.py +++ b/slisemap_interactive/layout.py @@ -1,27 +1,25 @@ -""" - Create the layout and register callbacks -""" +"""Create the layout and register callbacks.""" -from dash import Dash, html import pandas as pd +from dash import Dash, html from slisemap_interactive.plots import ( BarGroupingDropdown, ClusterDropdown, ContourCheckbox, - EmbeddingPlot, + DataCache, DensityTypeDropdown, + DistributionPlot, + EmbeddingPlot, + HoverData, JitterSlider, ModelBarPlot, ModelMatrixPlot, VariableDropdown, - HoverData, - DistributionPlot, - DataCache, ) -def register_callbacks(app: Dash, data: DataCache): +def register_callbacks(app: Dash, data: DataCache) -> None: """Register callbacks for updating the plots. Args: @@ -52,7 +50,6 @@ def page_with_all_plots(df: pd.DataFrame, data_key: int) -> html.Div: "alignItems": "center", "justifyContent": "right", "flexWrap": "wrap", - "gap": "0px", "padding": "0.4rem", "background": "#FDF3FF", "borderRadius": "4px", diff --git a/slisemap_interactive/load.py b/slisemap_interactive/load.py index 9ead2ba..51d6c8e 100644 --- a/slisemap_interactive/load.py +++ b/slisemap_interactive/load.py @@ -1,32 +1,53 @@ -""" - Load Slisemap objects and convert them into dataframes. -""" +"""Load Slisemap objects and convert them into dataframes.""" import gc +import warnings from os import PathLike from pathlib import Path -from typing import Optional, Union -import warnings +from typing import Any, List, Optional, Sequence, Union -import pandas as pd import numpy as np +import pandas as pd from sklearn.cluster import KMeans try: from slisemap import Slisemap except ImportError: warnings.warn( - "Could not import Slisemap, only limited functionality is available (no loading only plotting)" + "Could not import Slisemap, only limited functionality is available (no loading only plotting)", + stacklevel=1, ) class Slisemap: + """Placeholder Slisemap class.""" + @classmethod - def load(cls, *args, **kwargs): + def load(cls, *args: Any, **kwargs: Any) -> "Slisemap": + """Trigger the loading from the real Slisemap.""" from slisemap import Slisemap return Slisemap.load(*args, **kwargs) +try: + from slisemap import Slipmap +except ImportError: + warnings.warn( + "Could not import Slipmap, only limited functionality is available (no loading only plotting)", + stacklevel=1, + ) + + class Slipmap: + """Placeholder Slipmap class.""" + + @classmethod + def load(cls, *args: Any, **kwargs: Any) -> "Slipmap": + """Trigger the loading from the real Slipmap.""" + from slisemap import Slipmap + + return Slipmap.load(*args, **kwargs) + + # Defaults for subsampling the Slisemap object DEFAULT_MAX_N = 5000 DEFAULT_MAX_L = 250 @@ -34,6 +55,7 @@ def load(cls, *args, **kwargs): def subsample(Z: np.ndarray, n: int, clusters: Optional[int] = None) -> np.ndarray: """Get indices for subsampling. + Optionally uses k-means clustering to ensure inclusion of rarer data items. Args: @@ -81,18 +103,12 @@ def slisemap_to_dataframe( Returns: A dataframe containing data from the Slisemap object (columns: "X_*", "Y_*", "Z_*", "B_*", "Local loss", ("L_*", "Clusters *")). """ - if isinstance(path, Slisemap): - sm = path - else: - sm = Slisemap.load(path, "cpu") + sm = path if isinstance(path, Slisemap) else Slisemap.load(path, "cpu") Z = sm.get_Z(rotate=True) - if max_n > 0 and sm.n > max_n: - ss = subsample(Z, max_n) - else: - ss = ... + ss = subsample(Z, max_n) if max_n > 0 and sm.n > max_n else ... - def preface_names(names, preface): + def preface_names(names: Sequence, preface: str) -> List[str]: return [n if n[:2] == preface else preface + n for n in map(str, names)] variables = sm.metadata.get_variables(intercept=False) @@ -165,17 +181,14 @@ def preface_names(names, preface): def _extract_extension(path: Union[str, PathLike]) -> str: - if isinstance(path, str): - extension = path - else: - extension = Path(path).name + extension = path if isinstance(path, str) else Path(path).name return extension.split(".")[-1] def load( path: Union[Slisemap, pd.DataFrame, str, PathLike], extension: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> pd.DataFrame: """Load a dataframe or Slisemap object (into a dataframe). @@ -204,8 +217,11 @@ def load( return slisemap_to_dataframe(path, **kwargs) -def save_dataframe(df: pd.DataFrame, path: PathLike, extension: Optional[str] = None): +def save_dataframe( + df: pd.DataFrame, path: PathLike, extension: Optional[str] = None +) -> None: """Save dataframe to a file. + Supports csv, json, feather, and parquet. Args: @@ -237,6 +253,7 @@ def save_dataframe(df: pd.DataFrame, path: PathLike, extension: Optional[str] = def get_L_column(df: pd.DataFrame, index: Optional[int] = None) -> Optional[np.ndarray]: """Get a column of the L matrix from a `slisemap_to_dataframe`. + If `df` only contains a partial L matrix, then some values will be `np.nan`. If `df` does not contain L, then `None` is returned. diff --git a/slisemap_interactive/plots.py b/slisemap_interactive/plots.py index 0c0b186..05d8bbf 100644 --- a/slisemap_interactive/plots.py +++ b/slisemap_interactive/plots.py @@ -1,6 +1,4 @@ -""" - Functions and classes for generating dynamic plots. -""" +"""Functions and classes for generating dynamic plots.""" from typing import ( Any, @@ -32,7 +30,10 @@ PLOTLY_TEMPLATE = "slisemap_interactive" DEFAULT_TEMPLATE = "plotly_white+" + PLOTLY_TEMPLATE pio.templates[PLOTLY_TEMPLATE] = go.layout.Template( - layout=dict(margin=dict(l=10, r=10, t=30, b=20, autoexpand=True), uirevision=True) + layout={ + "margin": {"l": 10, "r": 10, "t": 30, "b": 20, "autoexpand": True}, + "uirevision": True, + } ) @@ -53,7 +54,7 @@ def try_twice(fn: Callable[[], Any], *args: Any, **kwargs: Any) -> Any: return fn(*args, **kwargs) -def nested_get(obj: Any, *keys) -> Optional[Any]: +def nested_get(obj: Any, *keys: Any) -> Optional[Any]: """Get a value from a nested object. Args: @@ -73,7 +74,7 @@ def nested_get(obj: Any, *keys) -> Optional[Any]: def first_not_none( objects: Sequence[Optional[Any]], map: Optional[Callable[[Any], Optional[Any]]] = None, - *args, + *args: Any, ) -> Optional[Any]: """Find the first value that is not `None` (with optional mapping function). @@ -115,10 +116,9 @@ def is_cluster_or_categorical(df: pd.DataFrame, column: str) -> bool: return True if "cluster" in column.lower(): return True - if is_object_dtype(col): - if len(col[:50].unique()) <= 10: - if len(col.unique()) <= 10: - return True + if is_object_dtype(col) and len(col[:50].unique()) <= 10: # noqa: SIM102 + if len(col.unique()) <= 10: + return True return False @@ -163,12 +163,13 @@ def get_variables( if loss_first: vars2 = [v for v in vars if v != "Local loss"] if len(vars) - 1 == len(vars2): - vars = ["Local loss"] + vars2 + vars = ["Local loss", *vars2] return vars def placeholder_figure(text: str) -> Dict[str, Any]: """Display a placeholder text instead of a graph. + This can be used in a "callback" function when a graph cannot be rendered. Args: @@ -224,6 +225,7 @@ class DataCache(dict): def add_data(self, df: pd.DataFrame) -> int: """Add a dataset to the cache. + This function checks for and reuses duplicate datasets. Args: @@ -243,6 +245,8 @@ def add_data(self, df: pd.DataFrame) -> int: class JitterSlider(html.Div): + """Slider for jitter.""" + def __init__( self, data: Optional[int] = None, @@ -251,8 +255,9 @@ def __init__( steps: int = 5, id: Optional[Any] = None, value: float = 0.0, - **kwargs, - ): + **kwargs: Any, + ) -> None: + """Create a jitter slider.""" values = np.linspace(0.0, scale, steps) marks = {0: "No jitter"} for v in values[1:]: @@ -265,10 +270,13 @@ def __init__( @classmethod def generate_id(cls, data: int, controls: str) -> Dict[str, Any]: + """Generate dash id.""" return {"type": cls.__name__, "data": data, "controls": controls} class VariableDropdown(dcc.Dropdown): + """Dropdown for selecting variable.""" + def __init__( self, df: pd.DataFrame, @@ -276,8 +284,9 @@ def __init__( controls: str = "default", id: Optional[Any] = None, value: Optional[str] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: + """Create variable dropdown.""" vars = get_variables(df) if value is None or value not in vars: value = vars[0] @@ -288,10 +297,13 @@ def __init__( @classmethod def generate_id(cls, data: int, controls: str) -> Dict[str, Any]: + """Generate dash id.""" return {"type": cls.__name__, "data": data, "controls": controls} class ClusterDropdown(dcc.Dropdown): + """Dropdown for selecting cluster.""" + def __init__( self, df: pd.DataFrame, @@ -299,8 +311,9 @@ def __init__( controls: str = "default", id: Optional[Any] = None, value: Optional[str] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: + """Create cluster dropdown.""" clusters = [c for c in df.columns if is_cluster_or_categorical(df, c)] if id is None: assert data is not None and controls is not None @@ -311,18 +324,22 @@ def __init__( @classmethod def generate_id(cls, data: int, controls: str) -> Dict[str, Any]: + """Generate dash id.""" return {"type": cls.__name__, "data": data, "controls": controls} class DensityTypeDropdown(dcc.Dropdown): + """Dropdown for selecting density plot type.""" + def __init__( self, data: Optional[int] = None, controls: str = "default", id: Optional[Any] = None, value: Optional[str] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: + """Create density type dropdown.""" if id is None: assert data is not None and controls is not None id = self.generate_id(data, controls) @@ -333,18 +350,22 @@ def __init__( @classmethod def generate_id(cls, data: int, controls: str) -> Dict[str, Any]: + """Generate dash id.""" return {"type": cls.__name__, "data": data, "controls": controls} class BarGroupingDropdown(dcc.Dropdown): + """Dropdown for selecting grouping.""" + def __init__( self, data: Optional[int] = None, controls: str = "default", id: Optional[Any] = None, value: Optional[str] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: + """Create grouping dropdown.""" if id is None: assert data is not None and controls is not None id = self.generate_id(data, controls) @@ -355,10 +376,13 @@ def __init__( @classmethod def generate_id(cls, data: int, controls: str) -> Dict[str, Any]: + """Generate dash id.""" return {"type": cls.__name__, "data": data, "controls": controls} class PredictionDropdown(dcc.Dropdown): + """Dropdown for selecting prediction.""" + def __init__( self, df: pd.DataFrame, @@ -366,8 +390,9 @@ def __init__( controls: str = "default", id: Optional[Any] = None, value: Optional[str] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: + """Create prediction dropdown.""" vars = [c for c in df.columns if c[0] == "Ŷ"] if len(vars) == 0: value = None @@ -382,18 +407,22 @@ def __init__( @classmethod def generate_id(cls, data: int, controls: str) -> Dict[str, Any]: + """Generate dash id.""" return {"type": cls.__name__, "data": data, "controls": controls} class ContourCheckbox(dcc.Checklist): + """Checkbox for contours.""" + def __init__( self, data: Optional[int] = None, controls: str = "default", id: Optional[Any] = None, value: bool = False, - **kwargs, - ): + **kwargs: Any, + ) -> None: + """Create contour checkbox.""" if id is None: assert data is not None and controls is not None id = self.generate_id(data, controls) @@ -401,19 +430,28 @@ def __init__( @classmethod def generate_id(cls, data: int, controls: str) -> Dict[str, Any]: + """Generate dash id.""" return {"type": cls.__name__, "data": data, "controls": controls} class EmbeddingPlot(dcc.Graph): + """Plot with 2D embedding.""" + def __init__( - self, data: int, controls: str = "default", hover: str = "default", **kwargs - ): + self, + data: int, + controls: str = "default", + hover: str = "default", + **kwargs: Any, + ) -> None: + """Create embedding plot.""" super().__init__( id=self.generate_id(data, controls, hover), clear_on_unhover=True, **kwargs ) @classmethod def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: + """Generate dash id.""" return { "type": cls.__name__, "data": data, @@ -424,10 +462,13 @@ def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: @classmethod def get_hover_index(cls, hover_data: Any) -> Optional[int]: + """Get the index of the current hover over point.""" return nested_get(hover_data, "points", 0, "customdata", 0) @classmethod - def register_callbacks(cls, app: Dash, data: DataCache): + def register_callbacks(cls, app: Dash, data: DataCache) -> None: + """Register Dash callbacks.""" + @app.callback( Output(cls.generate_id(MATCH, MATCH, MATCH), "figure"), Input(JitterSlider.generate_id(MATCH, MATCH), "value"), @@ -436,7 +477,7 @@ def register_callbacks(cls, app: Dash, data: DataCache): Input(ClusterDropdown.generate_id(MATCH, MATCH), "value"), Input(HoverData.generate_id(MATCH, MATCH), "data"), ) - def callback(jitter, variable, contour, cluster, hover): + def callback(jitter, variable, contour, cluster, hover) -> Figure: # noqa: ANN001 data_key = ctx.triggered_id["data"] df = data[data_key] dimensions = filter(lambda c: c[:2] == "Z_", df.columns) @@ -458,7 +499,9 @@ def plot( seed: int = 42, template: str = DEFAULT_TEMPLATE, ) -> Figure: - def dfmod(var): + """Create the plot.""" + + def dfmod(var: str) -> pd.DataFrame: df2 = pd.DataFrame( {x: df[x], y: df[y], var: df[var], "index": np.arange(df.shape[0])} ) @@ -544,9 +587,11 @@ def dfmod(var): trace = px.scatter(df2.iloc[[hover]], x=x, y=y).update_traces( hoverinfo="skip", hovertemplate=None, - marker=dict( - size=15, color="rgba(0,0,0,0)", line=dict(width=1, color="black") - ), + marker={ + "size": 15, + "color": "rgba(0,0,0,0)", + "line": {"width": 1, "color": "black"}, + }, ) fig.add_traces(trace.data) fig.update_yaxes(scaleanchor="x", scaleratio=1) @@ -555,15 +600,23 @@ def dfmod(var): class ModelMatrixPlot(dcc.Graph): + """Heatmap plot for the coefficient matrix.""" + def __init__( - self, data: int, controls: str = "default", hover: str = "default", **kwargs - ): + self, + data: int, + controls: str = "default", + hover: str = "default", + **kwargs: Any, + ) -> None: + """Create model matrix plot.""" super().__init__( id=self.generate_id(data, controls, hover), clear_on_unhover=True, **kwargs ) @classmethod def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: + """Generate dash id.""" return { "type": cls.__name__, "data": data, @@ -574,16 +627,19 @@ def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: @classmethod def get_hover_index(cls, hover_data: Any) -> Optional[int]: + """Get the index of the current hover over point.""" hover = nested_get(hover_data, "points", 0, "x") return int(hover) if hover is not None else None @classmethod - def register_callbacks(cls, app: Dash, data: DataCache): + def register_callbacks(cls, app: Dash, data: DataCache) -> None: + """Register Dash callbacks.""" + @app.callback( Output(cls.generate_id(MATCH, MATCH, MATCH), "figure"), Input(HoverData.generate_id(MATCH, MATCH), "data"), ) - def callback(hover): + def callback(hover: Optional[int]) -> Figure: data_key = ctx.triggered_id["data"] df = data[data_key] zs0 = next(filter(lambda c: c[:2] == "Z_", df.columns)) @@ -598,6 +654,7 @@ def plot( hover: Optional[int] = None, template: str = DEFAULT_TEMPLATE, ) -> Figure: + """Create the plot.""" if sort_by is None: order_to_sorted = np.arange(df.shape[0]) else: @@ -609,7 +666,7 @@ def plot( B_mat, color_continuous_midpoint=0, aspect="auto", - labels=dict(color="Coefficient", x="Data items sorted left to right"), + labels={"color": "Coefficient", "x": "Data items sorted left to right"}, title="Local models", color_continuous_scale="RdBu", y=coefficients, @@ -624,13 +681,23 @@ def plot( class ModelBarPlot(dcc.Graph): + """Barplot for the local model coefficients.""" + + GROUPING_OPTIONS = Literal["Variables", "Clusters"] + def __init__( - self, data: int, controls: str = "default", hover: str = "default", **kwargs - ): + self, + data: int, + controls: str = "default", + hover: str = "default", + **kwargs: Any, + ) -> None: + """Create the model bar plot.""" super().__init__(id=self.generate_id(data, controls, hover), **kwargs) @classmethod def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: + """Generate dash id.""" return { "type": cls.__name__, "data": data, @@ -639,14 +706,18 @@ def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: } @classmethod - def register_callbacks(cls, app: Dash, data: DataCache): + def register_callbacks(cls, app: Dash, data: DataCache) -> None: + """Register Dash callbacks.""" + @app.callback( Output(cls.generate_id(MATCH, MATCH, MATCH), "figure"), Input(ClusterDropdown.generate_id(MATCH, MATCH), "value"), Input(BarGroupingDropdown.generate_id(MATCH, MATCH), "value"), Input(HoverData.generate_id(MATCH, MATCH), "data"), ) - def callback(cluster, grouping, hover): + def callback( + cluster: Optional[str], grouping: cls.GROUPING_OPTIONS, hover: Optional[int] + ) -> Figure: data_key = ctx.triggered_id["data"] df = data[data_key] coefficients = [c for c in df.columns if c[:2] == "B_"] @@ -658,11 +729,9 @@ def callback(cluster, grouping, hover): Output(BarGroupingDropdown.generate_id(MATCH, MATCH), "disabled"), Input(ClusterDropdown.generate_id(MATCH, MATCH), "value"), ) - def callback_disabled(cluster): + def callback_disabled(cluster: Optional[str]) -> bool: return cluster is None - GROUPING_OPTIONS = Literal["Variables", "Clusters"] - @staticmethod def plot( df: pd.DataFrame, @@ -672,6 +741,7 @@ def plot( hover: Optional[int] = None, template: str = DEFAULT_TEMPLATE, ) -> Figure: + """Create the plot.""" coefficient_range = df[coefficients].abs().quantile(0.95).max() * 1.1 if hover is not None: fig = px.bar( @@ -689,7 +759,7 @@ def plot( .aggregate(["mean", "std"]) .stack(level=0) .reset_index() - .rename(columns=dict(level_1="Coefficients")) + .rename(columns={"level_1": "Coefficients"}) ) df2[cluster] = df2[cluster].astype("category") facet = grouping == "Clusters" @@ -725,13 +795,21 @@ def plot( class DistributionPlot(dcc.Graph): + """Distribution plot for the data and models.""" + def __init__( - self, data: int, controls: str = "default", hover: str = "default", **kwargs - ): + self, + data: int, + controls: str = "default", + hover: str = "default", + **kwargs: Any, + ) -> None: + """Create the distribution plot.""" super().__init__(id=self.generate_id(data, controls, hover), **kwargs) @classmethod def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: + """Generate dash id.""" return { "type": cls.__name__, "data": data, @@ -740,7 +818,9 @@ def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: } @classmethod - def register_callbacks(cls, app: Dash, data: DataCache): + def register_callbacks(cls, app: Dash, data: DataCache) -> None: + """Register Dash callbacks.""" + @app.callback( Output(cls.generate_id(MATCH, MATCH, MATCH), "figure"), Input(VariableDropdown.generate_id(MATCH, MATCH), "value"), @@ -748,13 +828,14 @@ def register_callbacks(cls, app: Dash, data: DataCache): Input(DensityTypeDropdown.generate_id(MATCH, MATCH), "value"), Input(HoverData.generate_id(MATCH, MATCH), "data"), ) - def callback(variable, cluster, histogram, hover): + def callback(variable, cluster, histogram, hover) -> Figure: # noqa: ANN001 data_key = ctx.triggered_id["data"] df = data[data_key] return try_twice(cls.plot, df, variable, histogram, cluster, hover) PLOT_TYPE_OPTIONS = Literal["Histogram", "Density"] + @staticmethod def plot( df: pd.DataFrame, variable: str, @@ -763,6 +844,7 @@ def plot( hover: Optional[int] = None, template: str = DEFAULT_TEMPLATE, ) -> Figure: + """Create the plot.""" if is_cluster_or_categorical(df, cluster): cats = get_categories(df[cluster]) if plot_type == "Histogram": @@ -789,7 +871,7 @@ def plot( fig = ff.create_distplot(data, clusters, show_hist=False, colors=colors) fig.update_layout( title=f"Density plot for {variable}", - legend=dict(title=cluster, traceorder="normal"), + legend={"title": cluster, "traceorder": "normal"}, ) if len(cats) < 4: fig.layout.yaxis.domain = [0.21, 1] @@ -818,13 +900,24 @@ def plot( class LinearTerms(dcc.Graph): + """Plot the local model coefficients times the variable values in a barplot. + + This plot assumes that the variables are scaled and the coefficients are linear. + """ + def __init__( - self, data: int, controls: str = "default", hover: str = "default", **kwargs - ): + self, + data: int, + controls: str = "default", + hover: str = "default", + **kwargs: Any, + ) -> None: + """Create the bar plot.""" super().__init__(id=self.generate_id(data, controls, hover), **kwargs) @classmethod def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: + """Generate dash id.""" return { "type": cls.__name__, "data": data, @@ -833,17 +926,20 @@ def generate_id(cls, data: int, controls: str, hover: str) -> Dict[str, Any]: } @classmethod - def register_callbacks(cls, app: Dash, data: DataCache): + def register_callbacks(cls, app: Dash, data: DataCache) -> None: + """Register Dash callbacks.""" + @app.callback( Output(cls.generate_id(MATCH, MATCH, MATCH), "figure"), Input(PredictionDropdown.generate_id(MATCH, MATCH), "value"), Input(HoverData.generate_id(MATCH, MATCH), "data"), ) - def callback(pred, hover): + def callback(pred: str, hover: Optional[int]) -> Figure: data_key = ctx.triggered_id["data"] df = data[data_key] return try_twice(cls.plot, df, pred, hover) + @staticmethod def plot( df: pd.DataFrame, pred: str, @@ -851,6 +947,7 @@ def plot( decimals: int = 3, template: str = DEFAULT_TEMPLATE, ) -> Figure: + """Create the plot.""" if hover is None: return no_update Xs = [c for c in df.columns if c[0] == "X"] @@ -888,19 +985,19 @@ def plot( tdec = int(np.max(np.log(np.abs(terms) + 1e-8)) // np.log(10)) tdec = decimals - min(decimals - 1, max(0, tdec)) text = [ - f"X × B = {x:.{xdec}g} × {b:.{bdec}g} = {i:.{tdec}g}" + f"X × B = {x:.{xdec}g} × {b:.{bdec}g} = {i:.{tdec}g}" # noqa: RUF001 for x, b, i in zip(xrow, brow, terms) ] xmax = np.max(np.abs(terms)) * 1.01 df2 = pd.DataFrame( - dict( - Variable=vars, - Value=xrow, - Coefficient=brow, - text=text, - sign=np.sign(terms), - Term=terms, - ) + { + "Variable": vars, + "Value": xrow, + "Coefficient": brow, + "text": text, + "sign": np.sign(terms), + "Term": terms, + } ) fig = px.bar( df2.iloc[::-1, :], @@ -928,17 +1025,22 @@ def plot( class HoverData(dcc.Store): - def __init__(self, data: int, hover: str = "default", **kwargs): + """Data store for the hover index.""" + + def __init__(self, data: int, hover: str = "default", **kwargs: Any) -> None: + """Create the data store.""" super().__init__( id=self.generate_id(data, hover), data=None, storage_type="memory", **kwargs ) @classmethod def generate_id(cls, data: int, hover: str) -> Dict[str, Any]: + """Generate dash id.""" return {"type": cls.__name__, "data": data, "hover": hover} @classmethod - def register_callbacks(cls, app: Dash, data: Optional[DataCache] = None): + def register_callbacks(cls, app: Dash, data: Optional[DataCache] = None) -> None: + """Register Dash callbacks.""" input = EmbeddingPlot.generate_id(MATCH, ALL, MATCH) input["type"] = ALL @@ -947,7 +1049,7 @@ def register_callbacks(cls, app: Dash, data: Optional[DataCache] = None): Input(input, "hoverData"), prevent_initial_call=True, ) - def hover_callback(inputs): + def hover_callback(inputs: Any) -> Optional[int]: tt = ctx.triggered_id["type"] if tt == "EmbeddingPlot": return first_not_none(inputs, EmbeddingPlot.get_hover_index) diff --git a/slisemap_interactive/xiplot.py b/slisemap_interactive/xiplot.py index 7cdd79a..65d80c3 100644 --- a/slisemap_interactive/xiplot.py +++ b/slisemap_interactive/xiplot.py @@ -1,7 +1,7 @@ -""" - Hooks for connecting to xiplot (using entry points in 'pyproject.toml'). -""" -from typing import Any, Callable, Dict, List +"""Hooks for connecting to xiplot (using entry points in 'pyproject.toml').""" + +from os import PathLike +from typing import Any, Callable, Dict, List, Optional import pandas as pd from dash import ALL, MATCH, Input, Output, State, dcc, html @@ -41,12 +41,14 @@ class LabelledControls(FlexRow): + """Wrapper that adds a label to a control.""" + def __init__( self, kwargs: Dict[str, Any] = {}, **controls: Any, - ): - """Wrap controls in a `FlexRow` with labels on top + ) -> None: + """Wrap controls in a `FlexRow` with labels on top. Args: **controls: `{label: control}`. @@ -73,27 +75,37 @@ def plugin_load() -> Dict[str, Callable[[Any], pd.DataFrame]]: """ # TODO Some columns should probably be hidden from the normal plots - def load(data, max_n: int = DEFAULT_MAX_N, max_l: int = DEFAULT_MAX_L): + def load( + data: PathLike, max_n: int = DEFAULT_MAX_N, max_l: int = DEFAULT_MAX_L + ) -> pd.DataFrame: + """Load the Slisemap.""" return slisemap_to_dataframe(data, max_n=max_n, index=False, losses=max_l) return load, ".sm" class SlisemapEmbeddingPlot(APlot): + """Embedding plot for Slisemap.""" + @classmethod def name(cls) -> str: + """Plot name.""" return "Slisemap embedding plot" @classmethod def help(cls) -> str: + """Help string.""" return ( "Plot the embedding of a Slisemap object\n\n" + 'Hover over a point when the color is based on "Local loss" to see alternative embeddings for that point.' ) @classmethod - def register_callbacks(cls, app, df_from_store, df_to_store): - PdfButton.register_callback(app, cls.get_id(None)) + def register_callbacks( + cls, app: object, df_from_store: Callable, df_to_store: Callable + ) -> None: + """Register callbacks.""" + PdfButton.register_callback(app, cls.name(), cls.get_id(None)) @app.callback( Output(cls.get_id(MATCH), "figure"), @@ -106,7 +118,7 @@ def register_callbacks(cls, app, df_from_store, df_to_store): State(ID_CLICKED, "data"), Input(ID_PLOTLY_TEMPLATE, "data"), ) - def callback(df, variable, cluster, contours, jitter, hover, click, template): + def callback(df, variable, cluster, contours, jitter, hover, click, template): # noqa: ANN001, ANN202 df = df_from_store(df) if cluster in df.columns: variable = cluster @@ -136,7 +148,7 @@ def callback(df, variable, cluster, contours, jitter, hover, click, template): Input(cls.get_id(ALL), "hoverData"), prevent_initial_call=True, ) - def hover_callback(inputs): + def hover_callback(inputs: Any) -> Optional[int]: return first_not_none(inputs, EmbeddingPlot.get_hover_index) @app.callback( @@ -145,7 +157,7 @@ def hover_callback(inputs): State(ID_CLICKED, "data"), prevent_initial_call=True, ) - def click_callback(inputs, old): + def click_callback(inputs, old) -> Optional[int]: # noqa: ANN001 new = first_not_none(inputs, EmbeddingPlot.get_hover_index) if new != old: return new @@ -154,27 +166,30 @@ def click_callback(inputs, old): PlotData.register_callback( cls.name(), app, - dict( - variable=Input(cls.get_id(MATCH, "variable"), "value"), - cluster=Input(cls.get_id(MATCH, "cluster"), "value"), - jitter=Input(cls.get_id(MATCH, "jitter"), "value"), - ), + { + "variable": Input(cls.get_id(MATCH, "variable"), "value"), + "cluster": Input(cls.get_id(MATCH, "cluster"), "value"), + "jitter": Input(cls.get_id(MATCH, "jitter"), "value"), + }, ) @classmethod - def create_layout(cls, index, df, columns, config=dict()) -> List[Any]: + def create_layout( + cls, index: object, df: pd.DataFrame, columns: Any, config: Dict[str, Any] = {} + ) -> List[object]: + """Create plot layout.""" return [ dcc.Graph(id=cls.get_id(index), clear_on_unhover=True), LabelledControls( Variable=VariableDropdown( df, id=cls.get_id(index, "variable"), - value=config.get("variable", None), + value=config.get("variable"), ), Clusters=ClusterDropdown( df, id=cls.get_id(index, "cluster"), - value=config.get("cluster", None), + value=config.get("cluster"), ), Density=ContourCheckbox( id=cls.get_id(index, "contours"), @@ -192,12 +207,16 @@ def create_layout(cls, index, df, columns, config=dict()) -> List[Any]: class SlisemapModelBarPlot(APlot): + """Bar plot for Slisemap models.""" + @classmethod def name(cls) -> str: + """Plot name.""" return "Slisemap barplot for local models" @classmethod def help(cls) -> str: + """Help string.""" return ( "Local models from a Slisemap object in a bar plot\n\n" + "The coefficients from the local models are plotted in a bar plot. " @@ -206,8 +225,11 @@ def help(cls) -> str: ) @classmethod - def register_callbacks(cls, app, df_from_store, df_to_store): - PdfButton.register_callback(app, cls.get_id(None)) + def register_callbacks( + cls, app: object, df_from_store: Callable, df_to_store: Callable + ) -> None: + """Register callbacks.""" + PdfButton.register_callback(app, cls.name(), cls.get_id(None)) @app.callback( Output(cls.get_id(MATCH), "figure"), @@ -218,7 +240,7 @@ def register_callbacks(cls, app, df_from_store, df_to_store): State(ID_CLICKED, "data"), Input(ID_PLOTLY_TEMPLATE, "data"), ) - def callback(df, clusters, grouping, hover, click, template): + def callback(df, clusters, grouping, hover, click, template): # noqa: ANN001, ANN202 df = df_from_store(df) bs = [c for c in df.columns if c[:2] == "B_"] if len(bs) == 0: @@ -240,51 +262,61 @@ def callback(df, clusters, grouping, hover, click, template): Input(cls.get_id(MATCH, "cluster"), "value"), prevent_initial_call=False, ) - def callback_disabled(cluster): + def callback_disabled(cluster: Optional[str]) -> bool: return cluster is None PlotData.register_callback( cls.name(), app, - dict( - cluster=Input(cls.get_id(MATCH, "cluster"), "value"), - grouping=Input(cls.get_id(MATCH, "grouping"), "value"), - ), + { + "cluster": Input(cls.get_id(MATCH, "cluster"), "value"), + "grouping": Input(cls.get_id(MATCH, "grouping"), "value"), + }, ) @classmethod - def create_layout(cls, index, df, columns, config=dict()) -> List[Any]: + def create_layout( + cls, index: object, df: pd.DataFrame, columns: Any, config: Dict[str, Any] = {} + ) -> List[object]: + """Create plot layout.""" return [ dcc.Graph(cls.get_id(index)), LabelledControls( Clusters=ClusterDropdown( df, id=cls.get_id(index, "cluster"), - value=config.get("cluster", None), + value=config.get("cluster"), ), Grouping=BarGroupingDropdown( id=cls.get_id(index, "grouping"), - value=config.get("grouping", None), + value=config.get("grouping"), ), ), ] class SlisemapModelMatrixPlot(APlot): + """Heatmap for Slisemap local models.""" + @classmethod def name(cls) -> str: + """Plot name.""" return "Slisemap matrixplot for local models" @classmethod def help(cls) -> str: + """Help string.""" return ( "Local models from a Slisemap object in a matrix plot\n\n" + "Hover over a column to see information about that point in other plots." ) @classmethod - def register_callbacks(cls, app, df_from_store, df_to_store): - PdfButton.register_callback(app, cls.get_id(None)) + def register_callbacks( + cls, app: object, df_from_store: Callable, df_to_store: Callable + ) -> None: + """Register callbacks.""" + PdfButton.register_callback(app, cls.name(), cls.get_id(None)) @app.callback( Output(cls.get_id(MATCH), "figure"), @@ -293,7 +325,7 @@ def register_callbacks(cls, app, df_from_store, df_to_store): State(ID_CLICKED, "data"), Input(ID_PLOTLY_TEMPLATE, "data"), ) - def callback(df, hover, click, template): + def callback(df, hover, click, template): # noqa: ANN001, ANN202 df = df_from_store(df) try: zs0 = next(filter(lambda c: c[:2] == "Z_", df.columns)) @@ -313,7 +345,7 @@ def callback(df, hover, click, template): Input(cls.get_id(ALL), "hoverData"), prevent_initial_call=True, ) - def hover_callback(inputs): + def hover_callback(inputs: Any) -> Optional[int]: return first_not_none(inputs, ModelMatrixPlot.get_hover_index) @app.callback( @@ -322,32 +354,42 @@ def hover_callback(inputs): State(ID_CLICKED, "data"), prevent_initial_call=True, ) - def click_callback(inputs, old): + def click_callback(inputs: Any, old: Any) -> Optional[int]: new = first_not_none(inputs, ModelMatrixPlot.get_hover_index) if new != old: return new return None @classmethod - def create_layout(cls, index, df, columns, config=dict()) -> List[Any]: + def create_layout( + cls, index: object, df: pd.DataFrame, columns: Any, config: Dict[str, Any] = {} + ) -> List[object]: + """Create plot layout.""" return [dcc.Graph(id=cls.get_id(index), clear_on_unhover=True)] class SlisemapDensityPlot(APlot): + """Density plot for Slisemap.""" + @classmethod def name(cls) -> str: + """Plot name.""" return "Slisemap density plot" # @classmethod # def help(cls) -> str: + # """Help string.""" # return ( # "Density plot for Slisemap objects\n\n" # + "Use clustering to easily compare the distribution of the values between different clusters." # ) @classmethod - def register_callbacks(cls, app, df_from_store, df_to_store): - PdfButton.register_callback(app, cls.get_id(None)) + def register_callbacks( + cls, app: object, df_from_store: Callable, df_to_store: Callable + ) -> None: + """Register callbacks.""" + PdfButton.register_callback(app, cls.name(), cls.get_id(None)) @app.callback( Output(cls.get_id(MATCH), "figure"), @@ -358,7 +400,7 @@ def register_callbacks(cls, app, df_from_store, df_to_store): State(ID_CLICKED, "data"), Input(ID_PLOTLY_TEMPLATE, "data"), ) - def callback(df, variable, cluster, hover, click, template): + def callback(df, variable, cluster, hover, click, template): # noqa: ANN001, ANN202 df = df_from_store(df) if variable not in df.columns: return placeholder_figure(f"{variable} not found") @@ -376,43 +418,53 @@ def callback(df, variable, cluster, hover, click, template): PlotData.register_callback( cls.name(), app, - dict( - variable=Input(cls.get_id(MATCH, "variable"), "value"), - cluster=Input(cls.get_id(MATCH, "cluster"), "value"), - ), + { + "variable": Input(cls.get_id(MATCH, "variable"), "value"), + "cluster": Input(cls.get_id(MATCH, "cluster"), "value"), + }, ) @classmethod - def create_layout(cls, index, df, columns, config=dict()): + def create_layout( + cls, index: object, df: pd.DataFrame, columns: Any, config: Dict[str, Any] = {} + ) -> List[object]: + """Create plot layout.""" return [ dcc.Graph(cls.get_id(index)), LabelledControls( Variable=VariableDropdown( df, id=cls.get_id(index, "variable"), - value=config.get("variable", None), + value=config.get("variable"), ), Clusters=ClusterDropdown( df, id=cls.get_id(index, "cluster"), - value=config.get("cluster", None), + value=config.get("cluster"), ), ), ] class SlisemapHistogramPlot(APlot): + """Histogram plot for Slisemap.""" + @classmethod def name(cls) -> str: + """Plot name.""" return "Slisemap histogram plot" # @classmethod # def help(cls) -> str: + """Help string.""" # return "Histogram for Slisemap objects" @classmethod - def register_callbacks(cls, app, df_from_store, df_to_store): - PdfButton.register_callback(app, cls.get_id(None)) + def register_callbacks( + cls, app: object, df_from_store: Callable, df_to_store: Callable + ) -> None: + """Register callbacks.""" + PdfButton.register_callback(app, cls.name(), cls.get_id(None)) @app.callback( Output(cls.get_id(MATCH), "figure"), @@ -423,7 +475,7 @@ def register_callbacks(cls, app, df_from_store, df_to_store): State(ID_CLICKED, "data"), Input(ID_PLOTLY_TEMPLATE, "data"), ) - def callback(df, variable, cluster, hover, click, template): + def callback(df, variable, cluster, hover, click, template): # noqa: ANN001, ANN202 df = df_from_store(df) if variable not in df.columns: return placeholder_figure(f"{variable} not found") @@ -442,38 +494,48 @@ def callback(df, variable, cluster, hover, click, template): PlotData.register_callback( cls.name(), app, - dict( - variable=Input(cls.get_id(MATCH, "variable"), "value"), - cluster=Input(cls.get_id(MATCH, "cluster"), "value"), - ), + { + "variable": Input(cls.get_id(MATCH, "variable"), "value"), + "cluster": Input(cls.get_id(MATCH, "cluster"), "value"), + }, ) @classmethod - def create_layout(cls, index, df, columns, config=dict()): + def create_layout( + cls, index: object, df: pd.DataFrame, columns: Any, config: Dict[str, Any] = {} + ) -> List[object]: + """Create plot layout.""" return [ dcc.Graph(cls.get_id(index)), LabelledControls( Variable=VariableDropdown( df, id=cls.get_id(index, "variable"), - value=config.get("variable", None), + value=config.get("variable"), ), Clusters=ClusterDropdown( df, id=cls.get_id(index, "cluster"), - value=config.get("cluster", None), + value=config.get("cluster"), ), ), ] class SlisemapLinearTermsPlot(APlot): + """Linear terms plot for Slisemap. + + This plot assumes that the variables have not been unscaled and that the coefficients are linear. + """ + @classmethod def name(cls) -> str: + """Plot name.""" return "Slisemap linear terms plot" @classmethod def help(cls) -> str: + """Help string.""" return ( "Linear terms plot for Slisemap objects\n\n" + 'Plot the "terms" of the linear models (variables times coefficients).' @@ -482,8 +544,11 @@ def help(cls) -> str: ) @classmethod - def register_callbacks(cls, app, df_from_store, df_to_store): - PdfButton.register_callback(app, cls.get_id(None)) + def register_callbacks( + cls, app: object, df_from_store: Callable, df_to_store: Callable + ) -> None: + """Register callbacks.""" + PdfButton.register_callback(app, cls.name(), cls.get_id(None)) @app.callback( Output(cls.get_id(MATCH), "figure"), @@ -493,7 +558,7 @@ def register_callbacks(cls, app, df_from_store, df_to_store): State(ID_CLICKED, "data"), Input(ID_PLOTLY_TEMPLATE, "data"), ) - def callback(df, pred, hover, click, template): + def callback(df, pred, hover, click, template): # noqa: ANN001, ANN202 df = df_from_store(df) if pred is None: return placeholder_figure("Could not find prediction") @@ -510,20 +575,21 @@ def callback(df, pred, hover, click, template): ) PlotData.register_callback( - cls.name(), app, dict(pred=Input(cls.get_id(MATCH, "pred"), "value")) + cls.name(), app, {"pred": Input(cls.get_id(MATCH, "pred"), "value")} ) @classmethod - def create_layout(cls, index, df, columns, config=dict()): + def create_layout( + cls, index: object, df: pd.DataFrame, columns: Any, config: Dict[str, Any] = {} + ) -> List[object]: + """Create plot layout.""" return [ dcc.Graph( cls.get_id(index), figure=placeholder_figure("Select an item to show") ), LabelledControls( Prediction=PredictionDropdown( - df, - id=cls.get_id(index, "pred"), - value=config.get("pred", None), + df, id=cls.get_id(index, "pred"), value=config.get("pred") ) ), ] diff --git a/tests/__init__py b/tests/__init__py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_load.py b/tests/test_load.py index 8bd1f02..271f21b 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,6 +1,6 @@ -from slisemap import Slisemap import numpy as np import pytest +from slisemap import Slisemap from slisemap_interactive.load import ( get_L_column, diff --git a/tests/test_plots.py b/tests/test_plots.py index dd8d5c5..295d812 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,6 +1,22 @@ -import pytest import pandas as pd -from slisemap_interactive.plots import * +import pytest + +from slisemap_interactive.plots import ( + BarGroupingDropdown, + ClusterDropdown, + ContourCheckbox, + DataCache, + DensityTypeDropdown, + DistributionPlot, + EmbeddingPlot, + HoverData, + JitterSlider, + ModelBarPlot, + ModelMatrixPlot, + VariableDropdown, + first_not_none, + nested_get, +) @pytest.fixture(scope="session") From e28a96a55f549aa7865b80f42123d8e83fc5e1e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Mon, 19 Feb 2024 15:23:52 +0200 Subject: [PATCH 02/11] test function signatures match xiplot --- slisemap_interactive/plots.py | 2 +- slisemap_interactive/xiplot.py | 11 ++- tests/test_layout.py | 7 ++ tests/test_xiplot.py | 125 +++++++++++++++++++++++++++++++++ 4 files changed, 138 insertions(+), 7 deletions(-) create mode 100644 tests/test_layout.py create mode 100644 tests/test_xiplot.py diff --git a/slisemap_interactive/plots.py b/slisemap_interactive/plots.py index 05d8bbf..1a90134 100644 --- a/slisemap_interactive/plots.py +++ b/slisemap_interactive/plots.py @@ -289,7 +289,7 @@ def __init__( """Create variable dropdown.""" vars = get_variables(df) if value is None or value not in vars: - value = vars[0] + value = vars[0] if len(vars) > 0 else None if id is None: assert data is not None and controls is not None id = self.generate_id(data, controls) diff --git a/slisemap_interactive/xiplot.py b/slisemap_interactive/xiplot.py index 65d80c3..f5e0624 100644 --- a/slisemap_interactive/xiplot.py +++ b/slisemap_interactive/xiplot.py @@ -1,7 +1,6 @@ """Hooks for connecting to xiplot (using entry points in 'pyproject.toml').""" -from os import PathLike -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import pandas as pd from dash import ALL, MATCH, Input, Output, State, dcc, html @@ -41,7 +40,7 @@ class LabelledControls(FlexRow): - """Wrapper that adds a label to a control.""" + """FlexRow wrapper that adds a labels to controls.""" def __init__( self, @@ -51,8 +50,8 @@ def __init__( """Wrap controls in a `FlexRow` with labels on top. Args: - **controls: `{label: control}`. kwargs: Additional key word arguments forwarded to `FlexRow` + **controls: `{label: control}`. """ children = [ html.Div( @@ -66,7 +65,7 @@ def __init__( super().__init__(*children, **kwargs) -def plugin_load() -> Dict[str, Callable[[Any], pd.DataFrame]]: +def plugin_load() -> Tuple[Callable[[object], pd.DataFrame], str]: """Xiplot plugin for reading Slisemap files. Returns: @@ -76,7 +75,7 @@ def plugin_load() -> Dict[str, Callable[[Any], pd.DataFrame]]: # TODO Some columns should probably be hidden from the normal plots def load( - data: PathLike, max_n: int = DEFAULT_MAX_N, max_l: int = DEFAULT_MAX_L + data: object, max_n: int = DEFAULT_MAX_N, max_l: int = DEFAULT_MAX_L ) -> pd.DataFrame: """Load the Slisemap.""" return slisemap_to_dataframe(data, max_n=max_n, index=False, losses=max_l) diff --git a/tests/test_layout.py b/tests/test_layout.py new file mode 100644 index 0000000..fac63d6 --- /dev/null +++ b/tests/test_layout.py @@ -0,0 +1,7 @@ +import pandas as pd + +from slisemap_interactive.layout import page_with_all_plots + + +def test_layout(): + page_with_all_plots(pd.DataFrame(), 0) diff --git a/tests/test_xiplot.py b/tests/test_xiplot.py new file mode 100644 index 0000000..c1d3b48 --- /dev/null +++ b/tests/test_xiplot.py @@ -0,0 +1,125 @@ +import inspect +from io import BytesIO +from typing import Any, Callable, Dict, Union, get_args, get_origin + +import numpy as np +import pandas as pd +from slisemap import Slisemap +from xiplot.plugin import APlot, AReadPlugin + +from slisemap_interactive.plots import JitterSlider +from slisemap_interactive.xiplot import ( + LabelledControls, + SlisemapDensityPlot, + SlisemapEmbeddingPlot, + SlisemapHistogramPlot, + SlisemapLinearTermsPlot, + SlisemapModelBarPlot, + SlisemapModelMatrixPlot, + plugin_load, +) + + +def unoptional(s: object) -> object: + """Unwrap Optional signature.""" + is_optional = ( + get_origin(s) == Union + and len(get_args(s)) == 2 + and get_args(s)[1] == type(None) + ) + return get_args(s)[0] if is_optional else s + + +def check_annotation(ann1: Dict[str, object], ann2: Dict[str, object]) -> bool: + """Naive check that annotations match.""" + ann1 = unoptional(ann1) + ann2 = unoptional(ann2) + if ann1 == ann2: + return True + if ann1 in [Any, object] or ann2 in [Any, object]: + return True + if get_origin(ann1) == get_origin(ann2): + return all( + check_annotation(s1i, s2i) + for s1i, s2i in zip(get_args(ann1), get_args(ann2)) + ) + return False + + +def assert_annotation_match(f1: Callable, ann2: Dict[str, object]): + """Naive assert that the annotation of function `f1` matches the annotation `ann2`.""" + s1 = f1.__annotations__ + s1.setdefault("return", None) + ann2.setdefault("return", None) + assert s1.keys() == ann2.keys(), f"{f1.__qualname__}: {s1} != {ann2}" + for i1, i2 in zip(s1.items(), ann2.items()): + assert check_annotation(i1, i2) + + +def signature_to_annotation(f: Callable) -> Dict[str, object]: + """Convert a signature into an annotation-like dictionary.""" + fsign = inspect.signature(f) + ann = {k: v.annotation for k, v in fsign.parameters.items()} + ann.setdefault("return", fsign.return_annotation) + return ann + + +def type_to_annotation(typ: type, reference: Callable): + """Convert a typing hint into an annotation-like dictionary.""" + fsign = inspect.signature(reference) + assert len(fsign.parameters) == len(get_args(typ)[0]) + ann = dict(zip(fsign.parameters, get_args(typ)[0])) + ann.setdefault("return", get_args(typ)[1]) + return ann + + +def test_load_signature(): + assert_annotation_match(plugin_load, type_to_annotation(AReadPlugin, plugin_load)) + assert_annotation_match(plugin_load, signature_to_annotation(plugin_load)) + + +def test_load(): + X = np.random.normal(0, 1, (10, 3)) + Y = np.random.normal(0, 1, 10) + B0 = np.random.normal(0, 1, (10, 4)) + sm = Slisemap(X, Y, lasso=0.1, B0=B0) + with BytesIO() as io: + sm.save(io) + io.seek(0) + sm2 = plugin_load()[0](io) + assert sm2.shape == (10, 30) + + +def test_plot_signature(): + plots = [ + SlisemapDensityPlot, + SlisemapEmbeddingPlot, + SlisemapHistogramPlot, + SlisemapLinearTermsPlot, + SlisemapModelBarPlot, + SlisemapModelMatrixPlot, + ] + for plot in plots: + assert issubclass(plot, APlot) + for name in plot.__dict__: + if name[0] != "_": + method = getattr(plot, name) + assert_annotation_match(method, getattr(APlot, name).__annotations__) + assert_annotation_match(method, signature_to_annotation(method)) + + +def test_plots(): + plots = [ + SlisemapDensityPlot, + SlisemapEmbeddingPlot, + SlisemapHistogramPlot, + SlisemapLinearTermsPlot, + SlisemapModelBarPlot, + SlisemapModelMatrixPlot, + ] + for plot in plots: + plot.create_layout(0, pd.DataFrame(), None, {}) + + +def test_labelled_controls(): + LabelledControls({"id": "test"}, test=JitterSlider(id="jitter")) From a87c9dcc93277ea3ca3d3fb8adfd81e2046ca547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Mon, 19 Feb 2024 19:43:29 +0200 Subject: [PATCH 03/11] support slipmap --- .gitignore | 2 +- pyproject.toml | 5 +- slisemap_interactive/__init__.py | 6 +- slisemap_interactive/load.py | 117 +++++++++++++++++++++++++++++-- slisemap_interactive/plots.py | 48 +++++++++---- slisemap_interactive/xiplot.py | 19 ++++- tests/test_load.py | 12 +++- tests/test_plots.py | 19 ++--- tests/test_xiplot.py | 18 +++-- 9 files changed, 210 insertions(+), 36 deletions(-) diff --git a/.gitignore b/.gitignore index 18051da..81e30e1 100644 --- a/.gitignore +++ b/.gitignore @@ -169,4 +169,4 @@ poetry.toml *.sm *.tmp *.txt -!*requitements*.txt \ No newline at end of file +*.sp \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 8966d29..e6518c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "slisemap_interactive" -version = "0.5.0" +version = "0.6.0" description = "Interactive plots for Slisemap using Dash" readme = "README.md" license = { text = "MIT" } @@ -35,7 +35,8 @@ dev = ["pytest", "pytest-cov", "black[jupyter]", "ruff", "jupyter", "IPython"] github = "https://github.com/edahelsinki/slisemap_interactive" [project.entry-points."xiplot.plugin.read"] -xiplot_plugin_load = "slisemap_interactive.xiplot:plugin_load" +xiplot_plugin_load_slisemap = "slisemap_interactive.xiplot:load_slisemap" +xiplot_plugin_load_slipmap = "slisemap_interactive.xiplot:load_slipmap" [project.entry-points."xiplot.plugin.plot"] xiplot_plugin_embeddingplot = "slisemap_interactive.xiplot:SlisemapEmbeddingPlot" diff --git a/slisemap_interactive/__init__.py b/slisemap_interactive/__init__.py index 49020d5..446587a 100644 --- a/slisemap_interactive/__init__.py +++ b/slisemap_interactive/__init__.py @@ -20,4 +20,8 @@ plot, shutdown, ) -from slisemap_interactive.load import load, slisemap_to_dataframe # noqa: F401 +from slisemap_interactive.load import ( # noqa: F401 + load, + slipmap_to_dataframe, + slisemap_to_dataframe, +) diff --git a/slisemap_interactive/load.py b/slisemap_interactive/load.py index 51d6c8e..c093c2f 100644 --- a/slisemap_interactive/load.py +++ b/slisemap_interactive/load.py @@ -51,6 +51,8 @@ def load(cls, *args: Any, **kwargs: Any) -> "Slipmap": # Defaults for subsampling the Slisemap object DEFAULT_MAX_N = 5000 DEFAULT_MAX_L = 250 +INDEX_COLUMN = "item" +PROTOTYPE_COLUMN = "Slipmap Prototype" def subsample(Z: np.ndarray, n: int, clusters: Optional[int] = None) -> np.ndarray: @@ -94,7 +96,7 @@ def slisemap_to_dataframe( """Convert a `Slisemap` object to a `pandas.DataFrame`. Args: - path: Slisemap object or path to a saved slisemap object. + path: Slisemap object or path to a saved Slisemap object. losses: Return the loss matrix. Can also be a number specifying the (approximate) maximum number of `L_*` columns. Default to True. clusters: Return cluster indices (if greater than one). Defaults to 9. max_n: maximum number of data items in the dataframe (subsampling is recommended if `n > 5000` and `losses=True`). Defaults to -1. @@ -174,12 +176,113 @@ def preface_names(names: Sequence, preface: str) -> List[str]: if index: df.index = rows else: - df.insert(0, "item", rows) + df.insert(0, INDEX_COLUMN, rows) del dfs gc.collect(1) return df +def slipmap_to_dataframe( + path: Union[str, PathLike, Slipmap], + losses: bool = True, + clusters: int = 9, + max_n: int = -1, + index: bool = True, +) -> pd.DataFrame: + """Convert a `Slipmap` object to a `pandas.DataFrame`. + + Args: + path: Slipmap object or path to a saved Slipmap object. + losses: Return the loss matrix. Can also be a number specifying the (approximate) maximum number of `L_*` columns. Default to True. + clusters: Return cluster indices (if greater than one). Defaults to 9. + max_n: maximum number of data items in the dataframe (subsampling is recommended if `n > 5000` and `losses=True`). Defaults to -1. + index: Return row names as the index (True) or as an "item" column (False). Defaults to True. + + Returns: + A dataframe containing data from the Slipmap object (columns: "X_*", "Y_*", "Z_*", "B_*", "Local loss", ("L_*", "Clusters *")). + """ + sp = path if isinstance(path, Slipmap) else Slipmap.load(path, "cpu") + + Z = sp.get_Z() + ss = subsample(Z, max_n) if max_n > 0 and sp.n > max_n else ... + + def preface_names(names: Sequence, preface: str) -> List[str]: + return [n if n[:2] == preface else preface + n for n in map(str, names)] + + variables = sp.metadata.get_variables(intercept=False) + variables = preface_names(variables, "X_") + targets = sp.metadata.get_targets() + if len(targets) > 1 or targets[0] != "Y": + targets = preface_names(targets, "Y_") + predictions = ["Ŷ" + t[1:] for t in targets] + coefficients = sp.metadata.get_coefficients() + coefficients = preface_names(coefficients, "B_") + dimensions = sp.metadata.get_dimensions() + dimensions = preface_names(dimensions, "Z_") + rows = sp.metadata.get_rows(fallback=False) + rows_proto = range(sp.n, sp.n + sp.p) + has_index = True + if ss is not ...: + rows = ss if rows is None else np.asarray(rows)[ss] + elif rows is None: + has_index = False + rows = range(sp.n) + + pred = sp.predict(sp._X[ss, :]) + local_loss = sp.local_loss(sp._Y, sp._as_new_Y(pred)).detach().cpu().numpy() + dfs = [ + pd.DataFrame.from_records(sp.metadata.unscale_X()[ss, :], columns=variables), + pd.DataFrame.from_records(sp.metadata.unscale_Y()[ss, :], columns=targets), + pd.DataFrame.from_records(Z[ss, :], columns=dimensions), + pd.DataFrame.from_records(sp.get_B()[ss, :], columns=coefficients), + pd.DataFrame.from_records(sp.metadata.unscale_Y(pred), columns=predictions), + pd.DataFrame({"Local loss": local_loss}), + ] + del variables, targets, predictions, Z, pred, local_loss + gc.collect(1) + + dfs2 = [ + pd.DataFrame.from_records(sp.get_Zp(), columns=dimensions), + pd.DataFrame.from_records(sp.get_Bp(), columns=coefficients), + ] + if losses: + L = sp.get_L(X=sp._X[ss, :], Y=sp._Y[ss, :])[ss, :] + Ln = [f"LT_{i}" for i in rows_proto] + dfs.append(pd.DataFrame.from_records(L.T, columns=Ln)) + del L + + if clusters > 1: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + clusters = { + f"Clusters {i}": pd.Series( + sp.get_model_clusters(i)[0][ss], dtype="category" + ) + for i in range(2, clusters + 1) + } + dfs.append(pd.DataFrame(clusters)) + del sp, dimensions, coefficients + gc.collect(1) + + # Then we create a dataframe to return + df1 = pd.concat(dfs, axis=1, copy=False) + df2 = pd.concat(dfs2, axis=1, copy=False) + df1[PROTOTYPE_COLUMN] = False + df2[PROTOTYPE_COLUMN] = True + df2.index = rows_proto + del dfs, dfs2 + if has_index: + if index: + df1.index = rows + else: + df1.insert(0, INDEX_COLUMN, rows) + df2.insert(0, INDEX_COLUMN, rows_proto) + df = pd.concat((df1, df2), axis=0, copy=False) + del df1, df2 + gc.collect(1) + return df + + def _extract_extension(path: Union[str, PathLike]) -> str: extension = path if isinstance(path, str) else Path(path).name return extension.split(".")[-1] @@ -204,6 +307,8 @@ def load( return path if isinstance(path, Slisemap): return slisemap_to_dataframe(path, **kwargs) + if isinstance(path, Slipmap): + return slipmap_to_dataframe(path, **kwargs) if extension is None: extension = _extract_extension(path) if extension == "csv": @@ -214,6 +319,8 @@ def load( return pd.read_json(path) if extension == "feather" or extension == "ft": return pd.read_feather(path) + if extension == "sp": + return slipmap_to_dataframe(path, **kwargs) return slisemap_to_dataframe(path, **kwargs) @@ -234,11 +341,11 @@ def save_dataframe( """ if extension is None: extension = _extract_extension(path) - if "item" not in df.columns and not ( + if INDEX_COLUMN not in df.columns and not ( isinstance(df.index, pd.RangeIndex) and df.index.identical(pd.RangeIndex.from_range(range(df.shape[0]))) ): - df = df.reset_index().rename(columns={"index": "item"}) + df = df.reset_index().rename(columns={"index": INDEX_COLUMN}) if extension == "csv": df.to_csv(path, index=False) elif extension == "json": @@ -264,7 +371,7 @@ def get_L_column(df: pd.DataFrame, index: Optional[int] = None) -> Optional[np.n Returns: Loss column. """ - rows = df.get("item", df.index) + rows = df.get(INDEX_COLUMN, df.index) col = df.get(f"L_{rows[index]}", None) if col is not None: return col diff --git a/slisemap_interactive/plots.py b/slisemap_interactive/plots.py index 1a90134..1e79a71 100644 --- a/slisemap_interactive/plots.py +++ b/slisemap_interactive/plots.py @@ -25,7 +25,7 @@ from plotly.graph_objects import Figure from scipy.stats import gaussian_kde -from slisemap_interactive.load import get_L_column +from slisemap_interactive.load import PROTOTYPE_COLUMN, get_L_column PLOTLY_TEMPLATE = "slisemap_interactive" DEFAULT_TEMPLATE = "plotly_white+" + PLOTLY_TEMPLATE @@ -502,13 +502,13 @@ def plot( """Create the plot.""" def dfmod(var: str) -> pd.DataFrame: - df2 = pd.DataFrame( - {x: df[x], y: df[y], var: df[var], "index": np.arange(df.shape[0])} - ) + df2 = df[[x, y, var]].copy() + df2["index"] = pd.RangeIndex(df2.shape[0]) if jitter > 0: + mult = 1.0 - df.get(PROTOTYPE_COLUMN, 0.0) prng = np.random.default_rng(seed) - df2[x] += prng.normal(0, jitter, df.shape[0]) - df2[y] += prng.normal(0, jitter, df.shape[0]) + df2[x] += prng.normal(0, mult * jitter, df.shape[0]) + df2[y] += prng.normal(0, mult * jitter, df.shape[0]) return df2 fig = None @@ -522,10 +522,12 @@ def dfmod(var: str) -> pd.DataFrame: y=y, color=variable, color_discrete_sequence=px.colors.qualitative.Plotly, + opacity=(1.0 - df.get(PROTOTYPE_COLUMN, 0.0)) * 0.8, symbol=variable, category_orders={variable: cats}, title="Embedding", custom_data=["index"], + render_mode="webgl", ) fig.update_traces(hovertemplate=None, hoverinfo="none") ll = False @@ -533,11 +535,11 @@ def dfmod(var: str) -> pd.DataFrame: ll = variable == "Local loss" if fig is None and ll and hover is not None: losses = get_L_column(df, hover) - if losses is not None: + if losses is not None and np.isfinite(losses[hover]): loss_cols = [c for c in df.columns if c[:2] == "L_" or c[:3] == "LT_"] lrange = ( - df[loss_cols].abs().min().quantile(0.05) * 0.9, - df[loss_cols].abs().max().quantile(0.95) * 1.1, + df[loss_cols].min().quantile(0.05) * 0.9, + df[loss_cols].max().quantile(0.95) * 1.1, ) df2 = dfmod(variable) df2[variable] = losses @@ -552,6 +554,7 @@ def dfmod(var: str) -> pd.DataFrame: labels={variable: "Local loss "}, custom_data=["index"], range_color=lrange, + render_mode="webgl", ) if fig is None: df2 = dfmod(variable) @@ -562,9 +565,10 @@ def dfmod(var: str) -> pd.DataFrame: color=variable, color_continuous_scale="Plasma_r", title="Embedding", - opacity=0.8, + opacity=(1.0 - df.get(PROTOTYPE_COLUMN, 0.0)) * 0.8, labels={variable: "Local loss "} if ll else None, custom_data=["index"], + render_mode="webgl", ) if ll: fig.update_traces(hovertemplate=None, hoverinfo="none") @@ -583,6 +587,24 @@ def dfmod(var: str) -> pd.DataFrame: line_color="grey", line_width=1, ) + if hover is None and df.get(PROTOTYPE_COLUMN) is not None: + trace = px.scatter( + df[df[PROTOTYPE_COLUMN]], + x=x, + y=y, + render_mode="webgl", + opacity=0.8, + ).update_traces( + hovertemplate=None, + hoverinfo="skip", + marker={ + "size": 8, + "symbol": "hexagon2", + "color": "rgba(0,0,0,0)", + "line": {"width": 1, "color": "black"}, + }, + ) + fig.add_traces(trace.data) if hover is not None: trace = px.scatter(df2.iloc[[hover]], x=x, y=y).update_traces( hoverinfo="skip", @@ -887,13 +909,15 @@ def plot( if plot_type == "Histogram": fig = px.histogram(df, variable, title=f"Histogram of {variable}") else: - fig = ff.create_distplot([df[variable]], [variable], show_hist=False) + fig = ff.create_distplot( + [df[variable].dropna()], [variable], show_hist=False + ) fig.layout.yaxis.domain = [0.21, 1] fig.layout.yaxis2.domain = [0, 0.19] fig.update_layout( showlegend=False, title=f"Density plot for {variable}" ) - if hover is not None: + if hover is not None and np.isfinite(df[variable].iloc[hover]): fig.add_vline(x=df[variable].iloc[hover]) fig.update_layout(template=template, xaxis_title=None, yaxis_title=None) return fig diff --git a/slisemap_interactive/xiplot.py b/slisemap_interactive/xiplot.py index f5e0624..74a5951 100644 --- a/slisemap_interactive/xiplot.py +++ b/slisemap_interactive/xiplot.py @@ -18,6 +18,7 @@ from slisemap_interactive.load import ( DEFAULT_MAX_L, DEFAULT_MAX_N, + slipmap_to_dataframe, slisemap_to_dataframe, ) from slisemap_interactive.plots import ( @@ -65,7 +66,7 @@ def __init__( super().__init__(*children, **kwargs) -def plugin_load() -> Tuple[Callable[[object], pd.DataFrame], str]: +def load_slisemap() -> Tuple[Callable[[object], pd.DataFrame], str]: """Xiplot plugin for reading Slisemap files. Returns: @@ -83,6 +84,22 @@ def load( return load, ".sm" +def load_slipmap() -> Tuple[Callable[[object], pd.DataFrame], str]: + """Xiplot plugin for reading Slipmap files. + + Returns: + parser: Function for parsing a Slipmap file to a dataframe. + extension: File extension. + """ + # TODO Some columns should probably be hidden from the normal plots + + def load(data: object, max_n: int = DEFAULT_MAX_N) -> pd.DataFrame: + """Load the Slipmap.""" + return slipmap_to_dataframe(data, max_n=max_n, index=False) + + return load, ".sp" + + class SlisemapEmbeddingPlot(APlot): """Embedding plot for Slisemap.""" diff --git a/tests/test_load.py b/tests/test_load.py index 271f21b..83ff18d 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,11 +1,12 @@ import numpy as np import pytest -from slisemap import Slisemap +from slisemap import Slipmap, Slisemap from slisemap_interactive.load import ( get_L_column, load, save_dataframe, + slipmap_to_dataframe, slisemap_to_dataframe, subsample, ) @@ -62,6 +63,15 @@ def test_load_slisemap(sm_to_df): slisemap_to_dataframe(sm, losses=20, max_n=20, clusters=0, index=False) +def test_load_slipmap(sm_to_df): + sm, dfm = sm_to_df + sp = Slipmap.convert(sm) + dfp = slipmap_to_dataframe(sp) + for col in dfm.columns: + if col[0] not in ("L", "B", "Ŷ", "C"): + assert np.allclose(dfm[col], dfp[col][: dfm.shape[0]], 1e-4) + + def test_rec_l(sm_to_df): sm, df1 = sm_to_df df2 = slisemap_to_dataframe(sm, losses=30, clusters=0) diff --git a/tests/test_plots.py b/tests/test_plots.py index 295d812..1d618fb 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,5 +1,7 @@ +import numpy as np import pandas as pd import pytest +from slisemap_interactive.load import PROTOTYPE_COLUMN from slisemap_interactive.plots import ( BarGroupingDropdown, @@ -22,14 +24,15 @@ @pytest.fixture(scope="session") def dataframe(): df = pd.DataFrame() - df["cls"] = pd.Categorical([1, 2, 3, 1, 2, 3]) - df["Local loss"] = [0.1, 0.2, 0.3, 0.3, 0.2, 0.1] - df["L_0"] = df["L_1"] = df["L_2"] = [0.3, 0.2, 0.1, 0.1, 0.2, 0.3] - df["B_0"] = df["B_1"] = df["B_2"] = [0.3, 0.2, 0.1, 0.3, 0.2, 0.1] - df["X_0"] = df["X_1"] = [1, 2, 3, 1, 2, 3] - df["Y_0"] = [4, 6, -2, -4, 4, 0] - df["Z_0"] = [3, 2, 3, 3, 1, 3] - df["Z_1"] = [1, 2, 3, 3, 1, 2] + df["cls"] = pd.Categorical([1, 2, 3, 1, 2, 3, np.nan]) + df["Local loss"] = [0.1, 0.2, 0.3, 0.3, 0.2, 0.1, np.nan] + df["L_0"] = df["L_1"] = df["L_2"] = [0.3, 0.2, 0.1, 0.1, 0.2, 0.3, np.nan] + df["B_0"] = df["B_1"] = df["B_2"] = [0.3, 0.2, 0.1, 0.3, 0.2, 0.1, 0.0] + df["X_0"] = df["X_1"] = [1, 2, 3, 1, 2, 3, np.nan] + df["Y_0"] = [4, 6, -2, -4, 4, 0, np.nan] + df["Z_0"] = [3, 2, 3, 3, 1, 3, 0] + df["Z_1"] = [1, 2, 3, 3, 1, 2, 0] + df[PROTOTYPE_COLUMN] = [False] * 6 + [True] return df diff --git a/tests/test_xiplot.py b/tests/test_xiplot.py index c1d3b48..50d7a3b 100644 --- a/tests/test_xiplot.py +++ b/tests/test_xiplot.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd -from slisemap import Slisemap +from slisemap import Slipmap, Slisemap from xiplot.plugin import APlot, AReadPlugin from slisemap_interactive.plots import JitterSlider @@ -16,7 +16,8 @@ SlisemapLinearTermsPlot, SlisemapModelBarPlot, SlisemapModelMatrixPlot, - plugin_load, + load_slipmap, + load_slisemap, ) @@ -74,8 +75,9 @@ def type_to_annotation(typ: type, reference: Callable): def test_load_signature(): - assert_annotation_match(plugin_load, type_to_annotation(AReadPlugin, plugin_load)) - assert_annotation_match(plugin_load, signature_to_annotation(plugin_load)) + for load in [load_slisemap, load_slipmap]: + assert_annotation_match(load, type_to_annotation(AReadPlugin, load)) + assert_annotation_match(load, signature_to_annotation(load)) def test_load(): @@ -86,8 +88,14 @@ def test_load(): with BytesIO() as io: sm.save(io) io.seek(0) - sm2 = plugin_load()[0](io) + sm2 = load_slisemap()[0](io) assert sm2.shape == (10, 30) + sp = Slipmap.convert(sm) + with BytesIO() as io: + sp.save(io) + io.seek(0) + sp2 = load_slipmap()[0](io) + assert sp2.shape[0] == 10 def test_plot_signature(): From 1f17b98a209001cf8ea5d0ba3b9c83460131a3a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Tue, 20 Feb 2024 13:30:15 +0200 Subject: [PATCH 04/11] optimise loading --- slisemap_interactive/load.py | 95 ++++++++++++++++++---------------- slisemap_interactive/xiplot.py | 13 +++-- tests/test_load.py | 25 ++++++--- tests/test_plots.py | 2 +- tests/test_xiplot.py | 10 ++-- 5 files changed, 81 insertions(+), 64 deletions(-) diff --git a/slisemap_interactive/load.py b/slisemap_interactive/load.py index c093c2f..dead351 100644 --- a/slisemap_interactive/load.py +++ b/slisemap_interactive/load.py @@ -50,7 +50,7 @@ def load(cls, *args: Any, **kwargs: Any) -> "Slipmap": # Defaults for subsampling the Slisemap object DEFAULT_MAX_N = 5000 -DEFAULT_MAX_L = 250 +DEFAULT_MAX_L = 200 INDEX_COLUMN = "item" PROTOTYPE_COLUMN = "Slipmap Prototype" @@ -89,7 +89,7 @@ def subsample(Z: np.ndarray, n: int, clusters: Optional[int] = None) -> np.ndarr def slisemap_to_dataframe( path: Union[str, PathLike, Slisemap], losses: Union[bool, int] = True, - clusters: int = 9, + clusters: Optional[range] = range(2, 9), max_n: int = -1, index: bool = True, ) -> pd.DataFrame: @@ -98,7 +98,7 @@ def slisemap_to_dataframe( Args: path: Slisemap object or path to a saved Slisemap object. losses: Return the loss matrix. Can also be a number specifying the (approximate) maximum number of `L_*` columns. Default to True. - clusters: Return cluster indices (if greater than one). Defaults to 9. + clusters: Return cluster indices. Defaults to range(2, 9). max_n: maximum number of data items in the dataframe (subsampling is recommended if `n > 5000` and `losses=True`). Defaults to -1. index: Return row names as the index (True) or as an "item" column (False). Defaults to True. @@ -131,61 +131,64 @@ def preface_names(names: Sequence, preface: str) -> List[str]: has_index = False rows = range(sm.n) + pred = sm.predict(X=sm._X[ss, :], B=sm._B[ss, :]) + local_loss = sm.local_loss(sm._as_new_Y(pred), sm._Y[ss, :]).detach().cpu().numpy() + B = sm.get_B()[ss, :] + Z = Z[ss, :] dfs = [ pd.DataFrame.from_records(sm.metadata.unscale_X()[ss, :], columns=variables), pd.DataFrame.from_records(sm.metadata.unscale_Y()[ss, :], columns=targets), - pd.DataFrame.from_records(Z[ss, :], columns=dimensions), - pd.DataFrame.from_records(sm.get_B()[ss, :], columns=coefficients), - pd.DataFrame.from_records( - sm.metadata.unscale_Y(sm.predict(X=sm._X[ss, :], B=sm._B[ss, :])), - columns=predictions, - ), + pd.DataFrame.from_records(Z, columns=dimensions), + pd.DataFrame.from_records(B, columns=coefficients), + pd.DataFrame.from_records(sm.metadata.unscale_Y(pred), columns=predictions), + pd.DataFrame({"Local loss": local_loss}), ] - del variables, targets, dimensions, coefficients, predictions - gc.collect(1) + del variables, targets, dimensions, coefficients, predictions, local_loss, pred - L = sm.get_L(X=sm._X[ss, :], Y=sm._Y[ss, :])[ss, :] - dfs.append(pd.DataFrame({"Local loss": L.diagonal()})) if not isinstance(losses, bool) and losses > 0 and losses * 2 < Z.shape[0]: - sel = subsample(Z[ss, :], losses) + sel = subsample(Z, losses) sel.sort() + L = sm.local_loss(sm.local_model(sm._X[ss], sm._B[ss][sel]), sm._Y[ss]) + L = L.detach().cpu().numpy() Ln = [f"LT_{rows[i]}" for i in sel] - dfs.append(pd.DataFrame.from_records(L.T[:, sel], columns=Ln)) + dfs.append(pd.DataFrame.from_records(L.T, columns=Ln)) + del L, Ln elif losses: + L = sm.local_loss(sm.local_model(sm._X[ss], sm._B[ss]), sm._Y[ss]) + L = L.detach().cpu().numpy() Ln = [f"L_{i}" for i in rows] dfs.append(pd.DataFrame.from_records(L, columns=Ln)) - del L, Z - gc.collect(1) + del L, Ln - if clusters > 1: + if clusters is not None and len(clusters) > 0: with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) clusters = { f"Clusters {i}": pd.Series( - sm.get_model_clusters(i)[0][ss], dtype="category" + sm.get_model_clusters(i, B=B, Z=Z)[0], dtype="category" ) - for i in range(2, clusters + 1) + for i in clusters } dfs.append(pd.DataFrame(clusters)) + del clusters + del sm, Z, B + gc.collect(1) # Then we create a dataframe to return - del sm - gc.collect(1) df = pd.concat(dfs, axis=1, copy=False) + del dfs if has_index: if index: df.index = rows else: df.insert(0, INDEX_COLUMN, rows) - del dfs - gc.collect(1) return df def slipmap_to_dataframe( path: Union[str, PathLike, Slipmap], losses: bool = True, - clusters: int = 9, + clusters: Optional[range] = range(2, 9), max_n: int = -1, index: bool = True, ) -> pd.DataFrame: @@ -194,7 +197,7 @@ def slipmap_to_dataframe( Args: path: Slipmap object or path to a saved Slipmap object. losses: Return the loss matrix. Can also be a number specifying the (approximate) maximum number of `L_*` columns. Default to True. - clusters: Return cluster indices (if greater than one). Defaults to 9. + clusters: Return cluster indices. Defaults to range(2, 9). max_n: maximum number of data items in the dataframe (subsampling is recommended if `n > 5000` and `losses=True`). Defaults to -1. index: Return row names as the index (True) or as an "item" column (False). Defaults to True. @@ -230,37 +233,40 @@ def preface_names(names: Sequence, preface: str) -> List[str]: pred = sp.predict(sp._X[ss, :]) local_loss = sp.local_loss(sp._Y, sp._as_new_Y(pred)).detach().cpu().numpy() + B = sp.get_B()[ss, :] dfs = [ pd.DataFrame.from_records(sp.metadata.unscale_X()[ss, :], columns=variables), pd.DataFrame.from_records(sp.metadata.unscale_Y()[ss, :], columns=targets), pd.DataFrame.from_records(Z[ss, :], columns=dimensions), - pd.DataFrame.from_records(sp.get_B()[ss, :], columns=coefficients), + pd.DataFrame.from_records(B, columns=coefficients), pd.DataFrame.from_records(sp.metadata.unscale_Y(pred), columns=predictions), pd.DataFrame({"Local loss": local_loss}), ] - del variables, targets, predictions, Z, pred, local_loss - gc.collect(1) - - dfs2 = [ - pd.DataFrame.from_records(sp.get_Zp(), columns=dimensions), - pd.DataFrame.from_records(sp.get_Bp(), columns=coefficients), - ] - if losses: - L = sp.get_L(X=sp._X[ss, :], Y=sp._Y[ss, :])[ss, :] - Ln = [f"LT_{i}" for i in rows_proto] - dfs.append(pd.DataFrame.from_records(L.T, columns=Ln)) - del L + del variables, targets, predictions, pred, local_loss - if clusters > 1: + if clusters is not None and len(clusters) > 0: with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) clusters = { f"Clusters {i}": pd.Series( - sp.get_model_clusters(i)[0][ss], dtype="category" + sp.get_model_clusters(i, B=B, Z=Z[ss, :])[0], dtype="category" ) - for i in range(2, clusters + 1) + for i in clusters } dfs.append(pd.DataFrame(clusters)) + del clusters + del Z, B + + if losses: + L = sp.get_L(X=sp._X[ss, :], Y=sp._Y[ss, :]) + Ln = [f"LT_{i}" for i in rows_proto] + dfs.append(pd.DataFrame.from_records(L.T, columns=Ln)) + del L + + dfs2 = [ + pd.DataFrame.from_records(sp.get_Zp(), columns=dimensions), + pd.DataFrame.from_records(sp.get_Bp(), columns=coefficients), + ] del sp, dimensions, coefficients gc.collect(1) @@ -277,10 +283,7 @@ def preface_names(names: Sequence, preface: str) -> List[str]: else: df1.insert(0, INDEX_COLUMN, rows) df2.insert(0, INDEX_COLUMN, rows_proto) - df = pd.concat((df1, df2), axis=0, copy=False) - del df1, df2 - gc.collect(1) - return df + return pd.concat((df1, df2), axis=0, copy=False) def _extract_extension(path: Union[str, PathLike]) -> str: diff --git a/slisemap_interactive/xiplot.py b/slisemap_interactive/xiplot.py index 74a5951..9a210e3 100644 --- a/slisemap_interactive/xiplot.py +++ b/slisemap_interactive/xiplot.py @@ -76,10 +76,15 @@ def load_slisemap() -> Tuple[Callable[[object], pd.DataFrame], str]: # TODO Some columns should probably be hidden from the normal plots def load( - data: object, max_n: int = DEFAULT_MAX_N, max_l: int = DEFAULT_MAX_L + data: object, + max_n: int = DEFAULT_MAX_N, + max_l: int = DEFAULT_MAX_L, + **kwargs: Any, ) -> pd.DataFrame: """Load the Slisemap.""" - return slisemap_to_dataframe(data, max_n=max_n, index=False, losses=max_l) + return slisemap_to_dataframe( + data, max_n=max_n, index=False, losses=max_l, **kwargs + ) return load, ".sm" @@ -93,9 +98,9 @@ def load_slipmap() -> Tuple[Callable[[object], pd.DataFrame], str]: """ # TODO Some columns should probably be hidden from the normal plots - def load(data: object, max_n: int = DEFAULT_MAX_N) -> pd.DataFrame: + def load(data: object, max_n: int = DEFAULT_MAX_N, **kwargs: Any) -> pd.DataFrame: """Load the Slipmap.""" - return slipmap_to_dataframe(data, max_n=max_n, index=False) + return slipmap_to_dataframe(data, max_n=max_n, index=False, **kwargs) return load, ".sp" diff --git a/tests/test_load.py b/tests/test_load.py index 83ff18d..858c81c 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -18,7 +18,7 @@ def sm_to_df(): Y = np.random.normal(0, 1, 100) B0 = np.random.normal(0, 1, (100, 6)) sm = Slisemap(X, Y, lasso=0.1, B0=B0) - df = slisemap_to_dataframe(sm, losses=True, clusters=8) + df = slisemap_to_dataframe(sm, losses=True) return sm, df @@ -47,38 +47,47 @@ def test_load_slisemap(sm_to_df): assert np.allclose(sm.get_Z(rotate=True)[:, 0], df["Z_0"]) assert df.shape[0] == sm.n assert df.shape[1] == 1 + sm.n + 7 + sm.m + sm.q + sm.o * 2 + sm.d - sm.intercept - df2 = slisemap_to_dataframe(sm, max_n=80, index=False, losses=False, clusters=3) + df2 = slisemap_to_dataframe(sm, max_n=80, index=False, losses=False, clusters=(3,)) assert df2.shape[0] == 80 sm.metadata.set_rows(range(1, sm.n + 1)) sm.metadata.set_variables(range(1, sm.m + 1 - sm.intercept)) sm.metadata.set_targets("test") sm.metadata.set_coefficients(sm.metadata.get_variables()) sm.metadata.set_dimensions("as") - df3 = slisemap_to_dataframe(sm, losses=10, clusters=0) + df3 = slisemap_to_dataframe(sm, losses=10, clusters=None) assert all(f"X_{i}" in df3 for i in sm.metadata.get_variables(intercept=False)) assert all(f"B_{i}" in df3 for i in sm.metadata.get_coefficients()) assert all(f"Y_{i}" in df3 for i in sm.metadata.get_targets()) assert all(f"Z_{i}" in df3 for i in sm.metadata.get_dimensions()) assert np.allclose(df3.index, sm.metadata.get_rows()) - slisemap_to_dataframe(sm, losses=20, max_n=20, clusters=0, index=False) + slisemap_to_dataframe(sm, losses=20, max_n=20, clusters=None, index=False) def test_load_slipmap(sm_to_df): sm, dfm = sm_to_df sp = Slipmap.convert(sm) - dfp = slipmap_to_dataframe(sp) + dfp = slipmap_to_dataframe(sp, clusters=None, losses=False) for col in dfm.columns: if col[0] not in ("L", "B", "Ŷ", "C"): - assert np.allclose(dfm[col], dfp[col][: dfm.shape[0]], 1e-4) + assert np.allclose( + dfm[col], dfp[col][: dfm.shape[0]], 2e-4 + ), f"{col} not equal" + dfp = slipmap_to_dataframe(sp, clusters=range(6, 7), losses=True) + for col in dfp.columns: + if col[:3] == "LT_": + assert np.all(np.isnan(dfp[col][sp.n :])) + assert np.all(np.isfinite(dfp[col][: sp.n])) + assert int(col[3:]) >= sp.n + assert np.sum(np.isnan(get_L_column(dfp, 0))) == sp.n - 1 def test_rec_l(sm_to_df): sm, df1 = sm_to_df - df2 = slisemap_to_dataframe(sm, losses=30, clusters=0) + df2 = slisemap_to_dataframe(sm, losses=30, clusters=None) for i in range(df1.shape[0]): l1 = get_L_column(df1, i) l2 = get_L_column(df2, i) - assert np.all(np.equal(l1, l2) + (np.isnan(l2))) + assert np.all(np.isclose(l1, l2) + np.isnan(l2)) def test_save(sm_to_df, tmp_path): diff --git a/tests/test_plots.py b/tests/test_plots.py index 1d618fb..c4d1b6f 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,8 +1,8 @@ import numpy as np import pandas as pd import pytest -from slisemap_interactive.load import PROTOTYPE_COLUMN +from slisemap_interactive.load import PROTOTYPE_COLUMN from slisemap_interactive.plots import ( BarGroupingDropdown, ClusterDropdown, diff --git a/tests/test_xiplot.py b/tests/test_xiplot.py index 50d7a3b..382c96c 100644 --- a/tests/test_xiplot.py +++ b/tests/test_xiplot.py @@ -88,14 +88,14 @@ def test_load(): with BytesIO() as io: sm.save(io) io.seek(0) - sm2 = load_slisemap()[0](io) - assert sm2.shape == (10, 30) - sp = Slipmap.convert(sm) + sm2 = load_slisemap()[0](io, clusters=None) + assert sm2.shape == (10, 22) + sp = Slipmap.convert(sm, prototypes=10) with BytesIO() as io: sp.save(io) io.seek(0) - sp2 = load_slipmap()[0](io) - assert sp2.shape[0] == 10 + sp2 = load_slipmap()[0](io, clusters=None) + assert sp2.shape == (20, 23) def test_plot_signature(): From fab6fcd819034f4317f42d6160dd4da5e7af5ceb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Tue, 20 Feb 2024 13:53:08 +0200 Subject: [PATCH 05/11] update GitHub actions --- .github/workflows/python-publish.yml | 20 +++++-------- .github/workflows/python-pytest.yml | 44 ++++++++++++++++++++-------- pyproject.toml | 9 +----- tests/test_layout.py | 19 ++++++++++++ 4 files changed, 60 insertions(+), 32 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 43fc772..ee3e604 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -11,23 +11,19 @@ on: jobs: deploy: runs-on: ubuntu-latest + permissions: + id-token: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip install build + python -m pip install --upgrade pip build - name: Build package - run: | - python -m build - python -c "import os, glob; assert os.path.getsize(sorted(glob.glob('dist/*-*.whl'))[-1]) > 10_000" - - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} + run: python -m build + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/python-pytest.yml b/.github/workflows/python-pytest.yml index 0e1c3bb..92b0833 100644 --- a/.github/workflows/python-pytest.yml +++ b/.github/workflows/python-pytest.yml @@ -4,7 +4,6 @@ name: tests on: - workflow_dispatch: push: branches: [master, main] paths: ["**.py"] @@ -12,27 +11,48 @@ on: paths: ["**.py"] jobs: - build: + test: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + cache: "pip" - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip install pytest build - python -m pip install --editable . + python -m pip install --upgrade pip pytest pytest-cov + python -m pip install -e ".[xiplot]" + - name: Test with pytest + run: | + pytest --cov-report term --cov=slisemap_interactive --cov-fail-under=8 + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - run: python -m pip install --upgrade pip build - name: Build package run: | python -m build - python -c "import os, glob; assert os.path.getsize(sorted(glob.glob('dist/*-*.whl'))[-1]) > 10_000" - - name: Test with pytest + python -c "import os, glob; assert os.path.getsize(sorted(glob.glob('dist/slisemap-*.whl'))[-1]) > 10_000" + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - run: python -m pip install --upgrade pip ruff + - name: Lint with Ruff run: | - pytest + ruff check --output-format=github + ruff format --check diff --git a/pyproject.toml b/pyproject.toml index e6518c1..fd1bf71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,14 +56,7 @@ packages = ["slisemap_interactive"] branch = true [tool.coverage.report] -exclude_also = [ - "_deprecated", - "print", - "plt.show", - "if verbose", - "ImportError", - "_warn", -] +exclude_also = ["print", "ImportError", "callback"] [tool.ruff.lint] select = [ diff --git a/tests/test_layout.py b/tests/test_layout.py index fac63d6..cfc8aaa 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -1,7 +1,26 @@ +import sys + import pandas as pd +import pytest +from slisemap_interactive.app import BackgroundApp, cli, shutdown from slisemap_interactive.layout import page_with_all_plots def test_layout(): page_with_all_plots(pd.DataFrame(), 0) + + +def test_background(): + app = BackgroundApp() + app.set_data(pd.DataFrame()) + app.shutdown() + shutdown() + + +def test_cli_parse(): + old_argv = sys.argv + sys.argv = [old_argv[0], "--export", "", ""] + with pytest.raises(FileNotFoundError): + cli() + sys.argv = old_argv From d90b0b10b9b406d38970f7aa1efd36b37255bf4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Tue, 20 Feb 2024 14:00:36 +0200 Subject: [PATCH 06/11] catch kmeans warnings in slisemap --- slisemap_interactive/load.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/slisemap_interactive/load.py b/slisemap_interactive/load.py index dead351..a5fa591 100644 --- a/slisemap_interactive/load.py +++ b/slisemap_interactive/load.py @@ -161,16 +161,14 @@ def preface_names(names: Sequence, preface: str) -> List[str]: del L, Ln if clusters is not None and len(clusters) > 0: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - clusters = { - f"Clusters {i}": pd.Series( - sm.get_model_clusters(i, B=B, Z=Z)[0], dtype="category" - ) - for i in clusters - } - dfs.append(pd.DataFrame(clusters)) - del clusters + clusters = { + f"Clusters {i}": pd.Series( + sm.get_model_clusters(i, B=B, Z=Z)[0], dtype="category" + ) + for i in clusters + } + dfs.append(pd.DataFrame(clusters)) + del clusters del sm, Z, B gc.collect(1) @@ -245,16 +243,14 @@ def preface_names(names: Sequence, preface: str) -> List[str]: del variables, targets, predictions, pred, local_loss if clusters is not None and len(clusters) > 0: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - clusters = { - f"Clusters {i}": pd.Series( - sp.get_model_clusters(i, B=B, Z=Z[ss, :])[0], dtype="category" - ) - for i in clusters - } - dfs.append(pd.DataFrame(clusters)) - del clusters + clusters = { + f"Clusters {i}": pd.Series( + sp.get_model_clusters(i, B=B, Z=Z[ss, :])[0], dtype="category" + ) + for i in clusters + } + dfs.append(pd.DataFrame(clusters)) + del clusters del Z, B if losses: From aa757e2dba63af7507644faaf922a5acc506b528 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Tue, 20 Feb 2024 14:15:19 +0200 Subject: [PATCH 07/11] plot tweaks --- slisemap_interactive/plots.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/slisemap_interactive/plots.py b/slisemap_interactive/plots.py index 1e79a71..370d7cb 100644 --- a/slisemap_interactive/plots.py +++ b/slisemap_interactive/plots.py @@ -549,7 +549,7 @@ def dfmod(var: str) -> pd.DataFrame: y=y, color=variable, title=f"Alternative locations for item: {df.get('item', df.index)[hover]}", - opacity=np.isfinite(losses) * 0.8, + opacity=np.isfinite(losses) * 0.8 + 0.05, color_continuous_scale="Viridis_r", labels={variable: "Local loss "}, custom_data=["index"], @@ -587,7 +587,7 @@ def dfmod(var: str) -> pd.DataFrame: line_color="grey", line_width=1, ) - if hover is None and df.get(PROTOTYPE_COLUMN) is not None: + if df.get(PROTOTYPE_COLUMN) is not None: trace = px.scatter( df[df[PROTOTYPE_COLUMN]], x=x, @@ -601,7 +601,7 @@ def dfmod(var: str) -> pd.DataFrame: "size": 8, "symbol": "hexagon2", "color": "rgba(0,0,0,0)", - "line": {"width": 1, "color": "black"}, + "line": {"width": 1, "color": "grey"}, }, ) fig.add_traces(trace.data) From dd38bae8b679cc3df6a43db7ad8d449119e77f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Tue, 20 Feb 2024 14:45:45 +0200 Subject: [PATCH 08/11] more integration tests --- .github/workflows/python-pytest.yml | 2 +- pyproject.toml | 2 +- slisemap_interactive/app.py | 2 +- slisemap_interactive/xiplot.py | 2 +- tests/{test_layout.py => test_app.py} | 9 ++++++--- tests/test_load.py | 2 +- tests/test_xiplot.py | 4 ++++ 7 files changed, 15 insertions(+), 8 deletions(-) rename tests/{test_layout.py => test_app.py} (57%) diff --git a/.github/workflows/python-pytest.yml b/.github/workflows/python-pytest.yml index 92b0833..c7eda47 100644 --- a/.github/workflows/python-pytest.yml +++ b/.github/workflows/python-pytest.yml @@ -29,7 +29,7 @@ jobs: python -m pip install -e ".[xiplot]" - name: Test with pytest run: | - pytest --cov-report term --cov=slisemap_interactive --cov-fail-under=8 + pytest --cov-report term --cov=slisemap_interactive --cov-fail-under=70 build: runs-on: ubuntu-latest diff --git a/pyproject.toml b/pyproject.toml index fd1bf71..f539c98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ packages = ["slisemap_interactive"] branch = true [tool.coverage.report] -exclude_also = ["print", "ImportError", "callback"] +exclude_also = ["print", "ImportError"] [tool.ruff.lint] select = [ diff --git a/slisemap_interactive/app.py b/slisemap_interactive/app.py index 305abfd..6adafaf 100644 --- a/slisemap_interactive/app.py +++ b/slisemap_interactive/app.py @@ -278,7 +278,7 @@ def display( Exception: The server must be started (through `BackgroundApp().run()`) before the plots are displayed. """ if self._display_call is None: - raise Exception( + raise RuntimeError( "You need to run `BackgroundApp().run()` before displaying results" ) if mode is None: diff --git a/slisemap_interactive/xiplot.py b/slisemap_interactive/xiplot.py index 9a210e3..ee10233 100644 --- a/slisemap_interactive/xiplot.py +++ b/slisemap_interactive/xiplot.py @@ -477,7 +477,7 @@ def name(cls) -> str: # @classmethod # def help(cls) -> str: - """Help string.""" + # """Help string.""" # return "Histogram for Slisemap objects" @classmethod diff --git a/tests/test_layout.py b/tests/test_app.py similarity index 57% rename from tests/test_layout.py rename to tests/test_app.py index cfc8aaa..08f0574 100644 --- a/tests/test_layout.py +++ b/tests/test_app.py @@ -3,7 +3,7 @@ import pandas as pd import pytest -from slisemap_interactive.app import BackgroundApp, cli, shutdown +from slisemap_interactive.app import BackgroundApp, _can_display_iframe, cli, shutdown from slisemap_interactive.layout import page_with_all_plots @@ -14,13 +14,16 @@ def test_layout(): def test_background(): app = BackgroundApp() app.set_data(pd.DataFrame()) + with pytest.raises(RuntimeError): + app.display() app.shutdown() shutdown() + _can_display_iframe() -def test_cli_parse(): +def test_cli_parse(tmp_path): old_argv = sys.argv - sys.argv = [old_argv[0], "--export", "", ""] + sys.argv = [old_argv[0], "--export", str(tmp_path / "out"), str(tmp_path / "in")] with pytest.raises(FileNotFoundError): cli() sys.argv = old_argv diff --git a/tests/test_load.py b/tests/test_load.py index 858c81c..1ea6c99 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -71,7 +71,7 @@ def test_load_slipmap(sm_to_df): if col[0] not in ("L", "B", "Ŷ", "C"): assert np.allclose( dfm[col], dfp[col][: dfm.shape[0]], 2e-4 - ), f"{col} not equal" + ), f"{col} not equal ({np.abs(dfm[col], dfp[col][: dfm.shape[0]]).max()})" dfp = slipmap_to_dataframe(sp, clusters=range(6, 7), losses=True) for col in dfp.columns: if col[:3] == "LT_": diff --git a/tests/test_xiplot.py b/tests/test_xiplot.py index 382c96c..3049fa0 100644 --- a/tests/test_xiplot.py +++ b/tests/test_xiplot.py @@ -7,6 +7,7 @@ from slisemap import Slipmap, Slisemap from xiplot.plugin import APlot, AReadPlugin +from slisemap_interactive.app import BackgroundApp from slisemap_interactive.plots import JitterSlider from slisemap_interactive.xiplot import ( LabelledControls, @@ -125,8 +126,11 @@ def test_plots(): SlisemapModelBarPlot, SlisemapModelMatrixPlot, ] + app = BackgroundApp() for plot in plots: plot.create_layout(0, pd.DataFrame(), None, {}) + plot.register_callbacks(app, lambda *_: None, lambda *_: None) + plot.name() def test_labelled_controls(): From 24cb2e0d659774765985c422430eb6f85ce8bbc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Tue, 20 Feb 2024 15:01:49 +0200 Subject: [PATCH 09/11] fix tests --- .github/workflows/python-pytest.yml | 3 ++- tests/test_load.py | 9 +++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/python-pytest.yml b/.github/workflows/python-pytest.yml index c7eda47..52ae171 100644 --- a/.github/workflows/python-pytest.yml +++ b/.github/workflows/python-pytest.yml @@ -41,8 +41,9 @@ jobs: - run: python -m pip install --upgrade pip build - name: Build package run: | + rm -f dist/*.whl python -m build - python -c "import os, glob; assert os.path.getsize(sorted(glob.glob('dist/slisemap-*.whl'))[-1]) > 10_000" + python -c "import os, glob; assert os.path.getsize(glob.glob('dist/*.whl')[-1]) > 10_000" lint: runs-on: ubuntu-latest diff --git a/tests/test_load.py b/tests/test_load.py index 1ea6c99..f03ff34 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -65,15 +65,12 @@ def test_load_slisemap(sm_to_df): def test_load_slipmap(sm_to_df): sm, dfm = sm_to_df - sp = Slipmap.convert(sm) + sp = Slipmap.convert(sm, prototypes=10) dfp = slipmap_to_dataframe(sp, clusters=None, losses=False) - for col in dfm.columns: - if col[0] not in ("L", "B", "Ŷ", "C"): - assert np.allclose( - dfm[col], dfp[col][: dfm.shape[0]], 2e-4 - ), f"{col} not equal ({np.abs(dfm[col], dfp[col][: dfm.shape[0]]).max()})" dfp = slipmap_to_dataframe(sp, clusters=range(6, 7), losses=True) for col in dfp.columns: + if col[0] in ("X", "Y"): + assert np.allclose(dfm[col], dfp[col][: dfm.shape[0]]) if col[:3] == "LT_": assert np.all(np.isnan(dfp[col][sp.n :])) assert np.all(np.isfinite(dfp[col][: sp.n])) From ee8bf809831ff37a65a6ff57b0da6d622d69bd81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Tue, 20 Feb 2024 15:18:27 +0200 Subject: [PATCH 10/11] fix gh action --- .github/workflows/python-pytest.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-pytest.yml b/.github/workflows/python-pytest.yml index 52ae171..0dc05e4 100644 --- a/.github/workflows/python-pytest.yml +++ b/.github/workflows/python-pytest.yml @@ -25,8 +25,9 @@ jobs: cache: "pip" - name: Install dependencies run: | - python -m pip install --upgrade pip pytest pytest-cov - python -m pip install -e ".[xiplot]" + python -m pip install --upgrade pip + python -m pip install xiplot pytest pytest-cov + python -m pip install -e . - name: Test with pytest run: | pytest --cov-report term --cov=slisemap_interactive --cov-fail-under=70 From b7b64ded597470a68d9f12b6810233f86b418a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Tue, 20 Feb 2024 15:34:25 +0200 Subject: [PATCH 11/11] update readme --- .github/workflows/python-pytest.yml | 2 +- README.md | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-pytest.yml b/.github/workflows/python-pytest.yml index 0dc05e4..6850639 100644 --- a/.github/workflows/python-pytest.yml +++ b/.github/workflows/python-pytest.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/README.md b/README.md index e1a172d..f259bc3 100644 --- a/README.md +++ b/README.md @@ -23,9 +23,11 @@ To use the plugin, just install the package in the same Python environment as [ ## Installation -To install __slisemap_interactive__ without manually downloading the repository run: +To install __slisemap_interactive__ run one of the following commands: ```bash +pip install slisemap_interactive +pip install slisemap_interactive[xiplot] pip install git+https://github.com/edahelsinki/slisemap_interactive ```