Skip to content

Commit

Permalink
Merge pull request #8 from aertslab/gpu-opt
Browse files Browse the repository at this point in the history
Gpu opt
  • Loading branch information
LukasMahieu authored Jul 10, 2024
2 parents 1a0571b + 6bb4410 commit 8cbb85e
Show file tree
Hide file tree
Showing 7 changed files with 638 additions and 391 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ dependencies = [
"tqdm",
"loguru",
"logomaker",
"pybigtools",
"pybigtools>=0.2.0",
"seaborn",
"pooch"
"pooch",
]

[project.optional-dependencies]
tfmodisco = [
"modisco-lite",
"vizsequence"
]

dev = [
Expand Down
22 changes: 21 additions & 1 deletion src/crested/pl/patterns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,22 @@
from loguru import logger

from ._contribution_scores import contribution_scores
from ._modisco_results import modisco_results, create_clustermap


def _optional_function_warning(*args, **kwargs):
logger.error(
"The requested functionality requires the 'tfmodisco' package, which is not installed. "
"Please install it with `pip install crested[tfmodisco]`.",
)


try:
from ._modisco_results import create_clustermap, modisco_results
except ImportError:
modisco_results = _optional_function_warning
create_clustermap = _optional_function_warning

if modisco_results is not None:
__all__ = ["contribution_scores", "modisco_results", "create_clustermap"]
else:
__all__ = ["contribution_scores"]
28 changes: 17 additions & 11 deletions src/crested/pl/patterns/_contribution_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from crested._logging import log_and_raise
from crested.pl._utils import render_plot

from ._utils import _plot_attribution_map, _plot_mutagenesis_map, grad_times_input_to_df, grad_times_input_to_df_mutagenesis
from ._utils import (
_plot_attribution_map,
_plot_mutagenesis_map,
grad_times_input_to_df,
grad_times_input_to_df_mutagenesis,
)


@log_and_raise(ValueError)
Expand Down Expand Up @@ -83,32 +88,33 @@ def contribution_scores(
start_idx = center - int(zoom_n_bases / 2)
scores = scores[:, :, start_idx : start_idx + zoom_n_bases, :]


# Plot
logger.info(f"Plotting contribution scores for {seqs_one_hot.shape[0]} sequence(s)")
for seq in range(seqs_one_hot.shape[0]):
fig_height_per_class = 2
fig = plt.figure(figsize=(50, fig_height_per_class * scores.shape[1]))
seq_class_x = seqs_one_hot[seq, start_idx : start_idx + zoom_n_bases, :]

if method == 'mutagenesis':
global_max = scores[seq].max()+0.25*np.abs(scores[seq].max())
global_min = scores[seq].min()-0.25*np.abs(scores[seq].min())
if method == "mutagenesis":
global_max = scores[seq].max() + 0.25 * np.abs(scores[seq].max())
global_min = scores[seq].min() - 0.25 * np.abs(scores[seq].min())
else:
mins = []
maxs = []
for i in range(scores.shape[1]):
seq_class_scores = scores[seq, i, :, :]
mins.append(np.min(seq_class_scores*seq_class_x))
maxs.append(np.max(seq_class_scores*seq_class_x))
global_max = np.array(maxs).max()+0.25*np.abs(np.array(maxs).max())
global_min = np.array(mins).min()-0.25*np.abs(np.array(mins).min())
mins.append(np.min(seq_class_scores * seq_class_x))
maxs.append(np.max(seq_class_scores * seq_class_x))
global_max = np.array(maxs).max() + 0.25 * np.abs(np.array(maxs).max())
global_min = np.array(mins).min() - 0.25 * np.abs(np.array(mins).min())

for i in range(scores.shape[1]):
seq_class_scores = scores[seq, i, :, :]
ax = plt.subplot(scores.shape[1], 1, i + 1)
if (method =='mutagenesis'):
mutagenesis_df = grad_times_input_to_df_mutagenesis(seq_class_x, seq_class_scores)
if method == "mutagenesis":
mutagenesis_df = grad_times_input_to_df_mutagenesis(
seq_class_x, seq_class_scores
)
_plot_mutagenesis_map(mutagenesis_df, ax=ax)
else:
intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores)
Expand Down
130 changes: 77 additions & 53 deletions src/crested/pl/patterns/_modisco_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import matplotlib.pyplot as plt
import modiscolite as modisco
import numpy as np
from loguru import logger
import pandas as pd
import seaborn as sns
from loguru import logger

from crested._logging import log_and_raise
from crested.pl._utils import render_plot
Expand Down Expand Up @@ -47,7 +47,7 @@ def _trim_pattern_by_ic(
contrib_scores = np.array(pattern["contrib_scores"])
if not pos_pattern:
contrib_scores = -contrib_scores
contrib_scores[contrib_scores < 0] = 1e-9 # avoid division by zero
contrib_scores[contrib_scores < 0] = 1e-9 # avoid division by zero

ic = modisco.util.compute_per_position_ic(
ppm=np.array(contrib_scores), background=background, pseudocount=pseudocount
Expand Down Expand Up @@ -293,102 +293,126 @@ def modisco_results(

render_plot(fig, **kwargs)


def plot_custom_xticklabels(
ax: plt.Axes,
sequences: List[Tuple[str, np.ndarray]],
col_order: List[int],
fontsize: int = 10,
dy: float = 0.012
ax: plt.Axes,
sequences: list[tuple[str, np.ndarray]],
col_order: list[int],
fontsize: int = 10,
dy: float = 0.012,
) -> None:
"""
Plot custom x-tick labels with varying letter heights.
Parameters:
- ax (plt.Axes): The axes object to plot on.
- sequences (list): List of tuples containing sequences and their corresponding heights.
- col_order (list): List of column indices after clustering.
- fontsize (int): Base font size for the letters.
- dy (float): Vertical adjustment factor for letter heights.
Parameters
----------
ax
The axes object to plot on.
sequences
List of tuples containing sequences and their corresponding heights.
col_order
List of column indices after clustering.
fontsize
Base font size for the letters.
dy
Vertical adjustment factor for letter heights.
"""
ax.set_xticks(np.arange(len(sequences)))
ax.set_xticklabels([])
ax.tick_params(axis='x', which='both', length=0)
ax.tick_params(axis="x", which="both", length=0)

for i, original_index in enumerate(col_order):
sequence, heights = sequences[original_index]
y_position = -0.02
for j, (char, height) in enumerate(zip(sequence, heights)):
for _, (char, height) in enumerate(zip(sequence, heights)):
char_fontsize = height * fontsize
text = ax.text(i, y_position, char, ha='center', va='center', color='black',
transform=ax.get_xaxis_transform(), fontsize=char_fontsize, rotation=270)
text = ax.text(
i,
y_position,
char,
ha="center",
va="center",
color="black",
transform=ax.get_xaxis_transform(),
fontsize=char_fontsize,
rotation=270,
)
renderer = ax.figure.canvas.get_renderer()
char_width = text.get_window_extent(renderer=renderer).width
_ = text.get_window_extent(renderer=renderer).width
y_position -= dy


def create_clustermap(
pattern_matrix: np.ndarray,
classes: List[str],
figsize: Tuple[int, int] = (15, 13),
grid: bool = False,
color_palette: Union[str, List[str]] = "hsv",
cmap: str = 'coolwarm',
center: float = 0,
method: str = 'average',
fig_path: Optional[str] = None,
pat_seqs: Optional[List[Tuple[str, np.ndarray]]] = None,
dy: float = 0.012
pattern_matrix: np.ndarray,
classes: list[str],
figsize: tuple[int, int] = (15, 13),
grid: bool = False,
color_palette: str | list[str] = "hsv",
cmap: str = "coolwarm",
center: float = 0,
method: str = "average",
fig_path: str | None = None,
pat_seqs: list[tuple[str, np.ndarray]] | None = None,
dy: float = 0.012,
) -> sns.matrix.ClusterGrid:
"""
Create a clustermap from the given pattern matrix and class labels with customizable options.
Parameters:
- pattern_matrix (np.ndarray): 2D NumPy array containing pattern data.
- classes (list): List of class labels.
- figsize (tuple): Size of the figure.
- grid (bool): Whether to add a grid to the heatmap.
- color_palette (str or list): Color palette for the row colors.
Parameters
----------
pattern_matrix
2D NumPy array containing pattern data.
classes
List of class labels.
figsize
Size of the figure.
grid
Whether to add a grid to the heatmap.
color_palette
Color palette for the row colors.
- cmap (str): Colormap for the clustermap.
- center (float): Value at which to center the colormap.
- method (str): Clustering method to use (e.g., 'average', 'single', 'complete').
- fig_path (str, optional): Path to save the figure.
- pat_seqs (list, optional): List of sequences to use as xticklabels.
- dy (float): Vertical adjustment factor for letter heights.
Returns:
- sns.matrix.ClusterGrid: The clustermap object.
Returns
-------
The clustermap object.
"""
data = pd.DataFrame(pattern_matrix)

if isinstance(color_palette, str):
palette = sns.color_palette(color_palette, len(set(classes)))
else:
palette = color_palette

class_lut = dict(zip(set(classes), palette))
row_colors = pd.Series(classes).map(class_lut)

xtick_labels = False if pat_seqs is not None else True

g = sns.clustermap(
data,
cmap=cmap,
figsize=figsize,
row_colors=row_colors,
yticklabels=classes,
center=center,
xticklabels=xtick_labels,
method=method
data,
cmap=cmap,
figsize=figsize,
row_colors=row_colors,
yticklabels=classes,
center=center,
xticklabels=xtick_labels,
method=method,
)
col_order = g.dendrogram_col.reordered_ind

for label in class_lut:
g.ax_col_dendrogram.bar(0, 0, color=class_lut[label], label=label, linewidth=0)

if grid:
ax = g.ax_heatmap
ax.grid(True, which='both', color='grey', linewidth=0.25)
ax.grid(True, which="both", color="grey", linewidth=0.25)
g.fig.canvas.draw()

if pat_seqs is not None:
plot_custom_xticklabels(g.ax_heatmap, pat_seqs, col_order, dy=dy)

Expand Down
53 changes: 52 additions & 1 deletion src/crested/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,55 @@
from loguru import logger

from . import data, losses, metrics, zoo
from ._configs import TaskConfig, default_configs
from ._crested import Crested
from ._tfmodisco import tfmodisco, match_h5_files_to_classes, process_patterns, create_pattern_matrix, generate_nucleotide_sequences


def _optional_function_warning(*args, **kwargs):
logger.error(
"The requested functionality requires the 'tfmodisco' package, which is not installed. "
"Please install it with `pip install crested[tfmodisco]`.",
)


# Conditional import for tfmodisco since optional
try:
from ._tfmodisco import (
create_pattern_matrix,
generate_nucleotide_sequences,
match_h5_files_to_classes,
process_patterns,
tfmodisco,
)
except ImportError:
create_pattern_matrix = _optional_function_warning
generate_nucleotide_sequences = _optional_function_warning
match_h5_files_to_classes = _optional_function_warning
process_patterns = _optional_function_warning
tfmodisco = _optional_function_warning

if tfmodisco is not None:
__all__ = [
"data",
"losses",
"metrics",
"zoo",
"TaskConfig",
"default_configs",
"Crested",
"create_pattern_matrix",
"generate_nucleotide_sequences",
"match_h5_files_to_classes",
"process_patterns",
"tfmodisco",
]
else:
__all__ = [
"data",
"losses",
"metrics",
"zoo",
"TaskConfig",
"default_configs",
"Crested",
]
6 changes: 1 addition & 5 deletions src/crested/tl/_modisco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import modiscolite as modisco
import numpy as np
import pandas as pd
from vizsequence.viz_sequence import *


def l1(X: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -39,7 +38,7 @@ def get_2d_data_from_patterns(
-------
Forward and reverse 2D data arrays.
"""
func = l1 if transformer == "l1" else magnitude
func = l1 if transformer == "l1" else None # magnitude not defined?
tracks = (
["hypothetical_contribs", "contrib_scores"]
if include_hypothetical
Expand Down Expand Up @@ -138,6 +137,3 @@ def read_html_to_dataframe(source: str):
except ValueError as e:
# Handle the case where no tables are found
return f"Error: {str(e)}"
except Exception as e:
# Handle any other unexpected exceptions
return f"An error occurred: {str(e)}"
Loading

0 comments on commit 8cbb85e

Please sign in to comment.