Skip to content

Commit

Permalink
use numpy histogram in scatter_density
Browse files Browse the repository at this point in the history
  • Loading branch information
quentinblampey committed May 17, 2024
1 parent aa44168 commit 935d578
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 35 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Changelog

All notable changes to this project will be documented in this file.

## [0.1.6] - 2024-xx-xx

### Changed
- `pm.pl.scatter_density` now uses `np.histogram2d`
- `datashader` dependency has been removed
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ dependencies = [
"matplotlib",
"readfcs >=1.1.0",
"flowutils",
"datashader",
"consensusclustering",
"minisom"
]
Expand Down
70 changes: 36 additions & 34 deletions pytometry/plotting/_scatter_density.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from typing import Literal # noqa: TYP001
from typing import List, Optional, Tuple, Union

import datashader as ds
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from datashader.mpl_ext import dsshow
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.scale import ScaleBase
Expand All @@ -25,6 +22,7 @@ def scatter_density(
y_lim: Optional[Tuple[float, float]] = None,
ax: Optional[Axes] = None,
figsize: Optional[tuple[int, int]] = None,
bins: Union[int, tuple[int, int]] = 500,
cmap: Union[str, List, Colormap] = "jet",
vmin: Optional[float] = None,
vmax: Optional[float] = None,
Expand Down Expand Up @@ -54,6 +52,8 @@ def scatter_density(
draw into an existing figure.
figsize (tuple), optional:
Figure size (width, height) if ``ax`` not provided. Defaults to (10, 10).
bins (int or tuple), optional:
Number of bins for the `np.histogram2d` function
cmap (str or list or :class:`matplotlib.colors.Colormap`), optional:
For scalar aggregates, a matplotlib colormap name or instance.
Alternatively, an iterable of colors can be passed and will be converted
Expand All @@ -69,41 +69,43 @@ def scatter_density(
Returns:
Scatter plot that displays cell density
"""
figsize = figsize if figsize is not None else (10, 10)
ax = plt.subplots(figsize=figsize)[1] if ax is None else ax
if x_label is None:
x_label = x
if y_label is None:
y_label = y
# Create df from anndata object
markers = [x, y]
joined = sc.get.obs_df(adata, keys=[*markers], layer=layer)

# Convert variables to np.array
x = np.array(joined[x])
y = np.array(joined[y])
if isinstance(bins, int):
bins = (bins, bins)

# Plot density with datashader
df = pd.DataFrame(dict(x=x, y=y))
dsartist = dsshow(
df,
ds.Point("x", "y"),
ds.count(),
vmin=vmin,
vmax=vmax,
norm=None,
# aspect="auto",
ax=ax,
cmap=cmap,
hist, xedges, yedges = np.histogram2d(
adata.obs_vector(x, layer=layer), adata.obs_vector(y, layer=layer), bins=bins
)

plt.colorbar(dsartist)
vmin = hist.min() if vmin is None else vmin
vmax = hist.max() if vmax is None else vmax

plt.xlim(x_lim)
plt.ylim(y_lim)
plt.yscale(x_scale)
plt.xscale(y_scale)
plt.xlabel(x_label)
plt.ylabel(y_label)
image = ax.imshow(
hist.T,
extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
norm=mcolors.Normalize(vmin=vmin, vmax=vmax),
cmap=_get_cmap_white_background(cmap),
aspect="auto",
origin="lower",
)
plt.colorbar(image, ax=ax)

ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_yscale(x_scale)
ax.set_xscale(y_scale)
ax.set_xlabel(x if x_label is None else x_label)
ax.set_ylabel(y if y_label is None else y_label)

plt.show()


def _get_cmap_white_background(cmap: Union[str, List, Colormap]) -> Colormap:
if isinstance(cmap, str):
cmap = plt.cm.get_cmap(cmap)

colors = cmap(np.arange(cmap.N))
colors[0] = np.array([1, 1, 1, 1])

return mcolors.ListedColormap(colors)

0 comments on commit 935d578

Please sign in to comment.