Skip to content

Commit

Permalink
⚡ Use numpy histogram in scatter_density (#70)
Browse files Browse the repository at this point in the history
* use numpy histogram in scatter_density

* pre-commit changelog

* fix tests: use colormap from mpl

* remove unused List type

* remove changelog file
  • Loading branch information
quentinblampey authored May 20, 2024
1 parent aa44168 commit 8e74529
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 38 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ dynamic = ["version"]
description = "Pytometry is a Python package for flow and mass cytometry analysis."
requires-python = '>= 3.9'
dependencies = [
"nbproject",
"numpy>=1.20.0",
"numba>=0.57",
"pandas<2.0.0,>=1.5.3",
Expand All @@ -24,7 +23,6 @@ dependencies = [
"matplotlib",
"readfcs >=1.1.0",
"flowutils",
"datashader",
"consensusclustering",
"minisom"
]
Expand All @@ -41,6 +39,7 @@ dev = [
test = [
"pytest>=6.0",
"pytest-cov",
"nbproject",
"nbproject_test >= 0.2.0",
]

Expand Down
75 changes: 39 additions & 36 deletions pytometry/plotting/_scatter_density.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from typing import Literal # noqa: TYP001
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union

import datashader as ds
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from datashader.mpl_ext import dsshow
from matplotlib import colormaps
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.scale import ScaleBase
Expand All @@ -25,7 +23,8 @@ def scatter_density(
y_lim: Optional[Tuple[float, float]] = None,
ax: Optional[Axes] = None,
figsize: Optional[tuple[int, int]] = None,
cmap: Union[str, List, Colormap] = "jet",
bins: Union[int, tuple[int, int]] = 500,
cmap: Union[str, Colormap] = "jet",
vmin: Optional[float] = None,
vmax: Optional[float] = None,
*,
Expand Down Expand Up @@ -54,6 +53,8 @@ def scatter_density(
draw into an existing figure.
figsize (tuple), optional:
Figure size (width, height) if ``ax`` not provided. Defaults to (10, 10).
bins (int or tuple), optional:
Number of bins for the `np.histogram2d` function
cmap (str or list or :class:`matplotlib.colors.Colormap`), optional:
For scalar aggregates, a matplotlib colormap name or instance.
Alternatively, an iterable of colors can be passed and will be converted
Expand All @@ -69,41 +70,43 @@ def scatter_density(
Returns:
Scatter plot that displays cell density
"""
figsize = figsize if figsize is not None else (10, 10)
ax = plt.subplots(figsize=figsize)[1] if ax is None else ax
if x_label is None:
x_label = x
if y_label is None:
y_label = y
# Create df from anndata object
markers = [x, y]
joined = sc.get.obs_df(adata, keys=[*markers], layer=layer)

# Convert variables to np.array
x = np.array(joined[x])
y = np.array(joined[y])
if isinstance(bins, int):
bins = (bins, bins)

# Plot density with datashader
df = pd.DataFrame(dict(x=x, y=y))
dsartist = dsshow(
df,
ds.Point("x", "y"),
ds.count(),
vmin=vmin,
vmax=vmax,
norm=None,
# aspect="auto",
ax=ax,
cmap=cmap,
hist, xedges, yedges = np.histogram2d(
adata.obs_vector(x, layer=layer), adata.obs_vector(y, layer=layer), bins=bins
)

plt.colorbar(dsartist)
vmin = hist.min() if vmin is None else vmin
vmax = hist.max() if vmax is None else vmax

plt.xlim(x_lim)
plt.ylim(y_lim)
plt.yscale(x_scale)
plt.xscale(y_scale)
plt.xlabel(x_label)
plt.ylabel(y_label)
image = ax.imshow(
hist.T,
extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
norm=mcolors.Normalize(vmin=vmin, vmax=vmax),
cmap=_get_cmap_white_background(cmap),
aspect="auto",
origin="lower",
)
plt.colorbar(image, ax=ax)

ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_yscale(x_scale)
ax.set_xscale(y_scale)
ax.set_xlabel(x if x_label is None else x_label)
ax.set_ylabel(y if y_label is None else y_label)

plt.show()


def _get_cmap_white_background(cmap: Union[str, Colormap]) -> Colormap:
if isinstance(cmap, str):
cmap = colormaps.get_cmap(cmap)

colors = cmap(np.arange(cmap.N))
colors[0] = np.array([1, 1, 1, 1])

return mcolors.ListedColormap(colors)

0 comments on commit 8e74529

Please sign in to comment.