Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use numpy histogram in scatter_density #70

Merged
merged 5 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ dynamic = ["version"]
description = "Pytometry is a Python package for flow and mass cytometry analysis."
requires-python = '>= 3.9'
dependencies = [
"nbproject",
"numpy>=1.20.0",
"numba>=0.57",
"pandas<2.0.0,>=1.5.3",
Expand All @@ -24,7 +23,6 @@ dependencies = [
"matplotlib",
"readfcs >=1.1.0",
"flowutils",
"datashader",
"consensusclustering",
"minisom"
]
Expand All @@ -41,6 +39,7 @@ dev = [
test = [
"pytest>=6.0",
"pytest-cov",
"nbproject",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I'm not sure why you need nbproject at all. For tests, you only need nbproject_test.

It's used once in the code:

`from nbproject._logger import logger` (https://github.com/scverse/pytometry/blob/main/tests/test_notebooks.py#L4), and this can easily be replaced with a plain `print` call or a standard Python `logging` statement.

"nbproject_test >= 0.2.0",
]

Expand Down
75 changes: 39 additions & 36 deletions pytometry/plotting/_scatter_density.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from typing import Literal # noqa: TYP001
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union

import datashader as ds
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from datashader.mpl_ext import dsshow
from matplotlib import colormaps
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.scale import ScaleBase
Expand All @@ -25,7 +23,8 @@ def scatter_density(
y_lim: Optional[Tuple[float, float]] = None,
ax: Optional[Axes] = None,
figsize: Optional[tuple[int, int]] = None,
cmap: Union[str, List, Colormap] = "jet",
bins: Union[int, tuple[int, int]] = 500,
cmap: Union[str, Colormap] = "jet",
vmin: Optional[float] = None,
vmax: Optional[float] = None,
*,
Expand Down Expand Up @@ -54,6 +53,8 @@ def scatter_density(
draw into an existing figure.
figsize (tuple), optional:
Figure size (width, height) if ``ax`` not provided. Defaults to (10, 10).
bins (int or tuple), optional:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a general point, but I don't see the point of repeating types in docstrings when you're already using type hints in the function signature (as you should!). It's annoying to keep both in sync, and Sphinx picks up the types from the function header anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally agree, I'll open a separate issue!

Number of bins for the `np.histogram2d` function
cmap (str or list or :class:`matplotlib.colors.Colormap`), optional:
For scalar aggregates, a matplotlib colormap name or instance.
Alternatively, an iterable of colors can be passed and will be converted
Expand All @@ -69,41 +70,43 @@ def scatter_density(
Returns:
Scatter plot that displays cell density
"""
figsize = figsize if figsize is not None else (10, 10)
ax = plt.subplots(figsize=figsize)[1] if ax is None else ax
if x_label is None:
x_label = x
if y_label is None:
y_label = y
# Create df from anndata object
markers = [x, y]
joined = sc.get.obs_df(adata, keys=[*markers], layer=layer)

# Convert variables to np.array
x = np.array(joined[x])
y = np.array(joined[y])
if isinstance(bins, int):
bins = (bins, bins)

# Plot density with datashader
df = pd.DataFrame(dict(x=x, y=y))
dsartist = dsshow(
df,
ds.Point("x", "y"),
ds.count(),
vmin=vmin,
vmax=vmax,
norm=None,
# aspect="auto",
ax=ax,
cmap=cmap,
hist, xedges, yedges = np.histogram2d(
adata.obs_vector(x, layer=layer), adata.obs_vector(y, layer=layer), bins=bins
)

plt.colorbar(dsartist)
vmin = hist.min() if vmin is None else vmin
vmax = hist.max() if vmax is None else vmax

plt.xlim(x_lim)
plt.ylim(y_lim)
plt.yscale(x_scale)
plt.xscale(y_scale)
plt.xlabel(x_label)
plt.ylabel(y_label)
image = ax.imshow(
hist.T,
extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
norm=mcolors.Normalize(vmin=vmin, vmax=vmax),
cmap=_get_cmap_white_background(cmap),
aspect="auto",
origin="lower",
)
plt.colorbar(image, ax=ax)

ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_yscale(x_scale)
ax.set_xscale(y_scale)
ax.set_xlabel(x if x_label is None else x_label)
ax.set_ylabel(y if y_label is None else y_label)

plt.show()


def _get_cmap_white_background(cmap: Union[str, Colormap]) -> Colormap:
    """Build a listed colormap from ``cmap`` with its first color set to white.

    Args:
        cmap: A colormap name, resolved through the ``matplotlib.colormaps``
            registry, or an existing colormap instance, used as-is.

    Returns:
        A new ``ListedColormap`` holding the ``cmap.N`` colors of the input,
        except that entry 0 is replaced by ``[1, 1, 1, 1]`` (opaque white in
        RGBA).
    """
    base = colormaps.get_cmap(cmap) if isinstance(cmap, str) else cmap

    # Evaluate the colormap at every integer index 0 .. N-1 to get its colors.
    rgba = base(np.arange(base.N))

    # Overwrite the lowest entry so the smallest normalized values render white.
    rgba[0] = np.array([1, 1, 1, 1])

    return mcolors.ListedColormap(rgba)
Loading