diff --git a/pyproject.toml b/pyproject.toml index 8620d21..de4dbab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ dynamic = ["version"] description = "Pytometry is a Python package for flow and mass cytometry analysis." requires-python = '>= 3.9' dependencies = [ - "nbproject", "numpy>=1.20.0", "numba>=0.57", "pandas<2.0.0,>=1.5.3", @@ -24,7 +23,6 @@ dependencies = [ "matplotlib", "readfcs >=1.1.0", "flowutils", - "datashader", "consensusclustering", "minisom" ] @@ -41,6 +39,7 @@ dev = [ test = [ "pytest>=6.0", "pytest-cov", + "nbproject", "nbproject_test >= 0.2.0", ] diff --git a/pytometry/plotting/_scatter_density.py b/pytometry/plotting/_scatter_density.py index 4e03cc4..ef0a6ff 100644 --- a/pytometry/plotting/_scatter_density.py +++ b/pytometry/plotting/_scatter_density.py @@ -1,13 +1,11 @@ from typing import Literal # noqa: TYP001 -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union -import datashader as ds +import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np -import pandas as pd -import scanpy as sc from anndata import AnnData -from datashader.mpl_ext import dsshow +from matplotlib import colormaps from matplotlib.axes import Axes from matplotlib.colors import Colormap from matplotlib.scale import ScaleBase @@ -25,7 +23,8 @@ def scatter_density( y_lim: Optional[Tuple[float, float]] = None, ax: Optional[Axes] = None, figsize: Optional[tuple[int, int]] = None, - cmap: Union[str, List, Colormap] = "jet", + bins: Union[int, tuple[int, int]] = 500, + cmap: Union[str, Colormap] = "jet", vmin: Optional[float] = None, vmax: Optional[float] = None, *, @@ -54,6 +53,8 @@ def scatter_density( draw into an existing figure. figsize (tuple), optional: Figure size (width, height) if ``ax`` not provided. Defaults to (10, 10). + bins (int or tuple), optional: + Number of bins for the `np.histogram2d` function cmap (str or list or :class:`matplotlib.colors.Colormap`), optional: For scalar aggregates, a matplotlib colormap name or instance. Alternatively, an iterable of colors can be passed and will be converted @@ -69,41 +70,43 @@ def scatter_density( Returns: Scatter plot that displays cell density """ - figsize = figsize if figsize is not None else (10, 10) ax = plt.subplots(figsize=figsize)[1] if ax is None else ax - if x_label is None: - x_label = x - if y_label is None: - y_label = y - # Create df from anndata object - markers = [x, y] - joined = sc.get.obs_df(adata, keys=[*markers], layer=layer) - # Convert variables to np.array - x = np.array(joined[x]) - y = np.array(joined[y]) + if isinstance(bins, int): + bins = (bins, bins) - # Plot density with datashader - df = pd.DataFrame(dict(x=x, y=y)) - dsartist = dsshow( - df, - ds.Point("x", "y"), - ds.count(), - vmin=vmin, - vmax=vmax, - norm=None, - # aspect="auto", - ax=ax, - cmap=cmap, + hist, xedges, yedges = np.histogram2d( + adata.obs_vector(x, layer=layer), adata.obs_vector(y, layer=layer), bins=bins ) - plt.colorbar(dsartist) + vmin = hist.min() if vmin is None else vmin + vmax = hist.max() if vmax is None else vmax - plt.xlim(x_lim) - plt.ylim(y_lim) - plt.yscale(x_scale) - plt.xscale(y_scale) - plt.xlabel(x_label) - plt.ylabel(y_label) + image = ax.imshow( + hist.T, + extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], + norm=mcolors.Normalize(vmin=vmin, vmax=vmax), + cmap=_get_cmap_white_background(cmap), + aspect="auto", + origin="lower", + ) + plt.colorbar(image, ax=ax) + + ax.set_xlim(x_lim) + ax.set_ylim(y_lim) + ax.set_yscale(x_scale) + ax.set_xscale(y_scale) + ax.set_xlabel(x if x_label is None else x_label) + ax.set_ylabel(y if y_label is None else y_label) plt.show() + + +def _get_cmap_white_background(cmap: Union[str, Colormap]) -> Colormap: + if isinstance(cmap, str): + cmap = colormaps.get_cmap(cmap) + + colors = cmap(np.arange(cmap.N)) + colors[0] = np.array([1, 1, 1, 1]) + + return mcolors.ListedColormap(colors)