Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/aertslab/CREsted
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasMahieu committed Jul 10, 2024
2 parents 7a0e080 + 9e68b6d commit 75d5ce7
Show file tree
Hide file tree
Showing 8 changed files with 664 additions and 406 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ dependencies = [
"tqdm",
"loguru",
"logomaker",
"pybigtools",
"pybigtools>=0.2.0",
"seaborn",
"pooch"
"pooch",
]

[project.optional-dependencies]
tfmodisco = [
"modisco-lite",
"vizsequence"
]

dev = [
Expand Down
41 changes: 26 additions & 15 deletions src/crested/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,29 +64,41 @@ def _read_chromsizes(chromsizes_file: PathLike) -> dict[str, int]:
return chromsizes_dict


def _extract_values_from_bigwig(bw_file, bed_file, target, target_region_width):
def _extract_values_from_bigwig(
bw_file: PathLike, bed_file: PathLike, target: str
) -> np.ndarray:
"""Extract target values from a bigWig file for regions specified in a BED file."""
if isinstance(bed_file, Path):
bed_file = str(bed_file)
if isinstance(bw_file, Path):
bw_file = str(bw_file)

if target == "mean":
values = list(
pybigtools.bigWigAverageOverBed(bw_file, bed=bed_file, names=None)
)
with pybigtools.open(bw_file, "r") as bw:
values = np.fromiter(
bw.average_over_bed(bed=bed_file, names=None, stats="mean0"),
dtype=np.float32,
)
elif target == "max":
values = list(
pybigtools.bigWigAverageOverBed(bw_file, bed=bed_file, names=None)
)
with pybigtools.open(bw_file, "r") as bw:
values = np.fromiter(
bw.average_over_bed(bed=bed_file, names=None, stats="max"),
dtype=np.float32,
)
elif target == "count":
values = list(
pybigtools.bigWigAverageOverBed(bw_file, bed=bed_file, names=None)
)
with pybigtools.open(bw_file, "r") as bw:
values = np.fromiter(
bw.average_over_bed(bed=bed_file, names=None, stats="sum"),
dtype=np.float32,
)
elif target == "logcount":
values = list(
pybigtools.bigWigAverageOverBed(bw_file, bed=bed_file, names=None)
)
with pybigtools.open(bw_file, "r") as bw:
values = np.log1p(
np.fromiter(
bw.average_over_bed(bed=bed_file, names=None, stats="sum"),
dtype=np.float32,
)
)
else:
raise ValueError(f"Unsupported target '{target}'")

Expand Down Expand Up @@ -427,7 +439,6 @@ def import_bigwigs(
bw_file,
bed_file,
target,
target_region_width,
)
for bw_file in bw_files
]
Expand All @@ -437,7 +448,7 @@ def import_bigwigs(
if target_region_width is not None:
os.remove(bed_file)

data_matrix = np.array(all_results)
data_matrix = np.vstack(all_results)

# Create DataFrame for AnnData
df = pd.DataFrame(
Expand Down
22 changes: 21 additions & 1 deletion src/crested/pl/patterns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,22 @@
from loguru import logger

from ._contribution_scores import contribution_scores
from ._modisco_results import modisco_results, create_clustermap


def _optional_function_warning(*args, **kwargs):
logger.error(
"The requested functionality requires the 'tfmodisco' package, which is not installed. "
"Please install it with `pip install crested[tfmodisco]`.",
)


try:
from ._modisco_results import create_clustermap, modisco_results
except ImportError:
modisco_results = _optional_function_warning
create_clustermap = _optional_function_warning

if modisco_results is not None:
__all__ = ["contribution_scores", "modisco_results", "create_clustermap"]
else:
__all__ = ["contribution_scores"]
28 changes: 17 additions & 11 deletions src/crested/pl/patterns/_contribution_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from crested._logging import log_and_raise
from crested.pl._utils import render_plot

from ._utils import _plot_attribution_map, _plot_mutagenesis_map, grad_times_input_to_df, grad_times_input_to_df_mutagenesis
from ._utils import (
_plot_attribution_map,
_plot_mutagenesis_map,
grad_times_input_to_df,
grad_times_input_to_df_mutagenesis,
)


@log_and_raise(ValueError)
Expand Down Expand Up @@ -83,32 +88,33 @@ def contribution_scores(
start_idx = center - int(zoom_n_bases / 2)
scores = scores[:, :, start_idx : start_idx + zoom_n_bases, :]


# Plot
logger.info(f"Plotting contribution scores for {seqs_one_hot.shape[0]} sequence(s)")
for seq in range(seqs_one_hot.shape[0]):
fig_height_per_class = 2
fig = plt.figure(figsize=(50, fig_height_per_class * scores.shape[1]))
seq_class_x = seqs_one_hot[seq, start_idx : start_idx + zoom_n_bases, :]

if method == 'mutagenesis':
global_max = scores[seq].max()+0.25*np.abs(scores[seq].max())
global_min = scores[seq].min()-0.25*np.abs(scores[seq].min())
if method == "mutagenesis":
global_max = scores[seq].max() + 0.25 * np.abs(scores[seq].max())
global_min = scores[seq].min() - 0.25 * np.abs(scores[seq].min())
else:
mins = []
maxs = []
for i in range(scores.shape[1]):
seq_class_scores = scores[seq, i, :, :]
mins.append(np.min(seq_class_scores*seq_class_x))
maxs.append(np.max(seq_class_scores*seq_class_x))
global_max = np.array(maxs).max()+0.25*np.abs(np.array(maxs).max())
global_min = np.array(mins).min()-0.25*np.abs(np.array(mins).min())
mins.append(np.min(seq_class_scores * seq_class_x))
maxs.append(np.max(seq_class_scores * seq_class_x))
global_max = np.array(maxs).max() + 0.25 * np.abs(np.array(maxs).max())
global_min = np.array(mins).min() - 0.25 * np.abs(np.array(mins).min())

for i in range(scores.shape[1]):
seq_class_scores = scores[seq, i, :, :]
ax = plt.subplot(scores.shape[1], 1, i + 1)
if (method =='mutagenesis'):
mutagenesis_df = grad_times_input_to_df_mutagenesis(seq_class_x, seq_class_scores)
if method == "mutagenesis":
mutagenesis_df = grad_times_input_to_df_mutagenesis(
seq_class_x, seq_class_scores
)
_plot_mutagenesis_map(mutagenesis_df, ax=ax)
else:
intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores)
Expand Down
130 changes: 77 additions & 53 deletions src/crested/pl/patterns/_modisco_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import matplotlib.pyplot as plt
import modiscolite as modisco
import numpy as np
from loguru import logger
import pandas as pd
import seaborn as sns
from loguru import logger

from crested._logging import log_and_raise
from crested.pl._utils import render_plot
Expand Down Expand Up @@ -47,7 +47,7 @@ def _trim_pattern_by_ic(
contrib_scores = np.array(pattern["contrib_scores"])
if not pos_pattern:
contrib_scores = -contrib_scores
contrib_scores[contrib_scores < 0] = 1e-9 # avoid division by zero
contrib_scores[contrib_scores < 0] = 1e-9 # avoid division by zero

ic = modisco.util.compute_per_position_ic(
ppm=np.array(contrib_scores), background=background, pseudocount=pseudocount
Expand Down Expand Up @@ -293,102 +293,126 @@ def modisco_results(

render_plot(fig, **kwargs)


def plot_custom_xticklabels(
ax: plt.Axes,
sequences: List[Tuple[str, np.ndarray]],
col_order: List[int],
fontsize: int = 10,
dy: float = 0.012
ax: plt.Axes,
sequences: list[tuple[str, np.ndarray]],
col_order: list[int],
fontsize: int = 10,
dy: float = 0.012,
) -> None:
"""
Plot custom x-tick labels with varying letter heights.
Parameters:
- ax (plt.Axes): The axes object to plot on.
- sequences (list): List of tuples containing sequences and their corresponding heights.
- col_order (list): List of column indices after clustering.
- fontsize (int): Base font size for the letters.
- dy (float): Vertical adjustment factor for letter heights.
Parameters
----------
ax
The axes object to plot on.
sequences
List of tuples containing sequences and their corresponding heights.
col_order
List of column indices after clustering.
fontsize
Base font size for the letters.
dy
Vertical adjustment factor for letter heights.
"""
ax.set_xticks(np.arange(len(sequences)))
ax.set_xticklabels([])
ax.tick_params(axis='x', which='both', length=0)
ax.tick_params(axis="x", which="both", length=0)

for i, original_index in enumerate(col_order):
sequence, heights = sequences[original_index]
y_position = -0.02
for j, (char, height) in enumerate(zip(sequence, heights)):
for _, (char, height) in enumerate(zip(sequence, heights)):
char_fontsize = height * fontsize
text = ax.text(i, y_position, char, ha='center', va='center', color='black',
transform=ax.get_xaxis_transform(), fontsize=char_fontsize, rotation=270)
text = ax.text(
i,
y_position,
char,
ha="center",
va="center",
color="black",
transform=ax.get_xaxis_transform(),
fontsize=char_fontsize,
rotation=270,
)
renderer = ax.figure.canvas.get_renderer()
char_width = text.get_window_extent(renderer=renderer).width
_ = text.get_window_extent(renderer=renderer).width
y_position -= dy


def create_clustermap(
pattern_matrix: np.ndarray,
classes: List[str],
figsize: Tuple[int, int] = (15, 13),
grid: bool = False,
color_palette: Union[str, List[str]] = "hsv",
cmap: str = 'coolwarm',
center: float = 0,
method: str = 'average',
fig_path: Optional[str] = None,
pat_seqs: Optional[List[Tuple[str, np.ndarray]]] = None,
dy: float = 0.012
pattern_matrix: np.ndarray,
classes: list[str],
figsize: tuple[int, int] = (15, 13),
grid: bool = False,
color_palette: str | list[str] = "hsv",
cmap: str = "coolwarm",
center: float = 0,
method: str = "average",
fig_path: str | None = None,
pat_seqs: list[tuple[str, np.ndarray]] | None = None,
dy: float = 0.012,
) -> sns.matrix.ClusterGrid:
"""
Create a clustermap from the given pattern matrix and class labels with customizable options.
Parameters:
- pattern_matrix (np.ndarray): 2D NumPy array containing pattern data.
- classes (list): List of class labels.
- figsize (tuple): Size of the figure.
- grid (bool): Whether to add a grid to the heatmap.
- color_palette (str or list): Color palette for the row colors.
Parameters
----------
pattern_matrix
2D NumPy array containing pattern data.
classes
List of class labels.
figsize
Size of the figure.
grid
Whether to add a grid to the heatmap.
color_palette
Color palette for the row colors.
- cmap (str): Colormap for the clustermap.
- center (float): Value at which to center the colormap.
- method (str): Clustering method to use (e.g., 'average', 'single', 'complete').
- fig_path (str, optional): Path to save the figure.
- pat_seqs (list, optional): List of sequences to use as xticklabels.
- dy (float): Vertical adjustment factor for letter heights.
Returns:
- sns.matrix.ClusterGrid: The clustermap object.
Returns
-------
The clustermap object.
"""
data = pd.DataFrame(pattern_matrix)

if isinstance(color_palette, str):
palette = sns.color_palette(color_palette, len(set(classes)))
else:
palette = color_palette

class_lut = dict(zip(set(classes), palette))
row_colors = pd.Series(classes).map(class_lut)

xtick_labels = False if pat_seqs is not None else True

g = sns.clustermap(
data,
cmap=cmap,
figsize=figsize,
row_colors=row_colors,
yticklabels=classes,
center=center,
xticklabels=xtick_labels,
method=method
data,
cmap=cmap,
figsize=figsize,
row_colors=row_colors,
yticklabels=classes,
center=center,
xticklabels=xtick_labels,
method=method,
)
col_order = g.dendrogram_col.reordered_ind

for label in class_lut:
g.ax_col_dendrogram.bar(0, 0, color=class_lut[label], label=label, linewidth=0)

if grid:
ax = g.ax_heatmap
ax.grid(True, which='both', color='grey', linewidth=0.25)
ax.grid(True, which="both", color="grey", linewidth=0.25)
g.fig.canvas.draw()

if pat_seqs is not None:
plot_custom_xticklabels(g.ax_heatmap, pat_seqs, col_order, dy=dy)

Expand Down
Loading

0 comments on commit 75d5ce7

Please sign in to comment.