
Commit

added first part of pattern analysis (atac only)
nkempynck committed Jun 26, 2024
1 parent f4288eb commit c4acbf5
Showing 7 changed files with 1,132 additions and 48 deletions.
292 changes: 252 additions & 40 deletions docs/tutorials/mouse_biccn.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
"seaborn",
"cmake",
"modisco-lite",
"vizsequence"
]

[project.optional-dependencies]
2 changes: 1 addition & 1 deletion src/crested/pl/patterns/__init__.py
@@ -1,2 +1,2 @@
from ._contribution_scores import contribution_scores
from ._modisco_results import modisco_results
from ._modisco_results import modisco_results, create_clustermap
107 changes: 107 additions & 0 deletions src/crested/pl/patterns/_modisco_results.py
@@ -5,6 +5,8 @@
import modiscolite as modisco
import numpy as np
from loguru import logger
import pandas as pd
import seaborn as sns

from crested._logging import log_and_raise
from crested.pl._utils import render_plot
@@ -290,3 +292,108 @@ def modisco_results(
kwargs["height"] = 2 * max_num_patterns

render_plot(fig, **kwargs)

def plot_custom_xticklabels(
    ax: plt.Axes,
    sequences: list[tuple[str, np.ndarray]],
    col_order: list[int],
    fontsize: int = 10,
    dy: float = 0.012,
) -> None:
    """
    Plot custom x-tick labels with varying letter heights.

    Parameters:
    - ax (plt.Axes): The axes object to plot on.
    - sequences (list): List of tuples containing sequences and their corresponding letter heights.
    - col_order (list): List of column indices after clustering.
    - fontsize (int): Base font size for the letters.
    - dy (float): Vertical adjustment factor for letter heights.
    """
    ax.set_xticks(np.arange(len(sequences)))
    ax.set_xticklabels([])
    ax.tick_params(axis='x', which='both', length=0)

    for i, original_index in enumerate(col_order):
        sequence, heights = sequences[original_index]
        y_position = -0.02
        for char, height in zip(sequence, heights):
            # Scale each letter by its height and stack the letters vertically below the axis.
            char_fontsize = height * fontsize
            ax.text(i, y_position, char, ha='center', va='center', color='black',
                    transform=ax.get_xaxis_transform(), fontsize=char_fontsize, rotation=270)
            y_position -= dy

def create_clustermap(
    pattern_matrix: np.ndarray,
    classes: list[str],
    figsize: tuple[int, int] = (15, 13),
    grid: bool = False,
    color_palette: str | list[str] = "hsv",
    cmap: str = 'coolwarm',
    center: float = 0,
    method: str = 'average',
    fig_path: str | None = None,
    pat_seqs: list[tuple[str, np.ndarray]] | None = None,
    dy: float = 0.012,
) -> sns.matrix.ClusterGrid:
    """
    Create a clustermap from the given pattern matrix and class labels with customizable options.

    Parameters:
    - pattern_matrix (np.ndarray): 2D NumPy array containing pattern data.
    - classes (list): List of class labels.
    - figsize (tuple): Size of the figure.
    - grid (bool): Whether to add a grid to the heatmap.
    - color_palette (str or list): Color palette for the row colors.
    - cmap (str): Colormap for the clustermap.
    - center (float): Value at which to center the colormap.
    - method (str): Clustering method to use (e.g., 'average', 'single', 'complete').
    - fig_path (str, optional): Path to save the figure.
    - pat_seqs (list, optional): List of sequences to use as xticklabels.
    - dy (float): Vertical adjustment factor for letter heights.

    Returns:
    - sns.matrix.ClusterGrid: The clustermap object.
    """
    data = pd.DataFrame(pattern_matrix)

    if isinstance(color_palette, str):
        palette = sns.color_palette(color_palette, len(set(classes)))
    else:
        palette = color_palette

    class_lut = dict(zip(set(classes), palette))
    row_colors = pd.Series(classes).map(class_lut)

    # Suppress the default x-tick labels when custom sequence labels are drawn instead.
    xtick_labels = pat_seqs is None

    g = sns.clustermap(
        data,
        cmap=cmap,
        figsize=figsize,
        row_colors=row_colors,
        yticklabels=classes,
        center=center,
        xticklabels=xtick_labels,
        method=method
    )
    col_order = g.dendrogram_col.reordered_ind

    # Invisible zero-size bars provide legend handles for the class-to-color mapping.
    for label in class_lut:
        g.ax_col_dendrogram.bar(0, 0, color=class_lut[label], label=label, linewidth=0)
    g.ax_col_dendrogram.legend(loc="center", ncol=4)

    if grid:
        ax = g.ax_heatmap
        ax.grid(True, which='both', color='grey', linewidth=0.25)
        g.fig.canvas.draw()

    if pat_seqs is not None:
        plot_custom_xticklabels(g.ax_heatmap, pat_seqs, col_order, dy=dy)

    if fig_path is not None:
        plt.savefig(fig_path)

    plt.show()
    return g
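
For orientation, here is a minimal usage sketch of the new create_clustermap helper on a toy pattern-by-class matrix; the random data, class names, and keyword values below are illustrative only and not part of this commit:

import numpy as np
from crested.pl.patterns import create_clustermap

rng = np.random.default_rng(0)
pattern_matrix = rng.normal(size=(8, 20))        # toy matrix: 8 classes (rows) x 20 patterns (columns)
classes = [f"cell_type_{i}" for i in range(8)]   # hypothetical class labels, one per row

g = create_clustermap(
    pattern_matrix,
    classes,
    figsize=(12, 8),
    method="average",
    pat_seqs=None,   # pass (sequence, heights) tuples to draw letter-style x labels instead
)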
2 changes: 1 addition & 1 deletion src/crested/tl/__init__.py
@@ -1,4 +1,4 @@
from . import data, losses, metrics, zoo
from ._configs import TaskConfig, default_configs
from ._crested import Crested
from ._tfmodisco import tfmodisco
from ._tfmodisco import tfmodisco, match_h5_files_to_classes, process_patterns, create_pattern_matrix, generate_nucleotide_sequences
122 changes: 122 additions & 0 deletions src/crested/tl/_modisco_utils.py
@@ -0,0 +1,122 @@
from __future__ import annotations

import h5py
import html5lib
import modiscolite as modisco
from modiscolite.core import TrackSet, Seqlet, SeqletSet
import numpy as np
import os
import pandas as pd
from sklearn.decomposition import PCA
from typing import Dict, List, Tuple, Callable, Optional, Union
from vizsequence.viz_sequence import *

def l1(X: np.ndarray) -> np.ndarray:
    """
    Normalize the input array using the L1 norm.

    Parameters:
    - X (np.ndarray): Input array.

    Returns:
    - np.ndarray: L1-normalized array.
    """
    abs_sum = np.sum(np.abs(X))
    return X if abs_sum == 0 else (X / abs_sum)

def get_2d_data_from_patterns(
    pattern: Dict,
    transformer: str = 'l1',
    include_hypothetical: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Get 2D data from a pattern using the specified transformer.

    Parameters:
    - pattern (dict): Dictionary containing pattern data.
    - transformer (str): Transformer function to use ('l1' or 'magnitude').
    - include_hypothetical (bool): Whether to include hypothetical contributions.

    Returns:
    - tuple: Forward and reverse 2D data arrays.
    """
    func = l1 if transformer == 'l1' else magnitude
    tracks = ['hypothetical_contribs', 'contrib_scores'] if include_hypothetical else ['contrib_scores']

    all_fwd_data, all_rev_data = [], []
    snippets = [pattern[track] for track in tracks]

    # Concatenate the normalized tracks per position; the reverse view flips both the position and channel axes.
    fwd_data = np.concatenate([func(snippet) for snippet in snippets], axis=1)
    rev_data = np.concatenate([func(snippet[::-1, ::-1]) for snippet in snippets], axis=1)

    all_fwd_data.append(fwd_data)
    all_rev_data.append(rev_data)

    return np.array(all_fwd_data), np.array(all_rev_data)


def pad_pattern(pattern: Dict, pad_len: int = 2) -> Dict:
    """
    Pad the pattern with zeros on both sides.

    Parameters:
    - pattern (dict): Dictionary containing the pattern data.
    - pad_len (int): Length of padding.

    Returns:
    - dict: Padded pattern.
    """
    p0 = pattern.copy()
    p0['contrib_scores'] = np.concatenate(
        (np.zeros((pad_len, 4)), p0['contrib_scores'], np.zeros((pad_len, 4)))
    )
    p0['hypothetical_contribs'] = np.concatenate(
        (np.zeros((pad_len, 4)), p0['hypothetical_contribs'], np.zeros((pad_len, 4)))
    )
    return p0

def match_score_patterns(a: Dict, b: Dict) -> float:
    """
    Compute the match score between two patterns.

    Parameters:
    - a (dict): First pattern.
    - b (dict): Second pattern.

    Returns:
    - float: Match score between the patterns.
    """
    a = pad_pattern(a)
    fwd_data_A, rev_data_A = get_2d_data_from_patterns(a)
    fwd_data_B, rev_data_B = get_2d_data_from_patterns(b)
    # Order the arguments so the shorter pattern comes first, then compare B against both
    # the forward and the reverse orientation of A and keep the best score.
    X = fwd_data_B if fwd_data_B.shape[1] <= fwd_data_A.shape[1] else fwd_data_A
    Y = fwd_data_A if fwd_data_B.shape[1] <= fwd_data_A.shape[1] else fwd_data_B
    sim_fwd_pattern = np.array(modisco.affinitymat.jaccard(X, Y).squeeze())
    X = fwd_data_B if fwd_data_B.shape[1] <= fwd_data_A.shape[1] else rev_data_A
    Y = rev_data_A if fwd_data_B.shape[1] <= fwd_data_A.shape[1] else fwd_data_B
    sim_rev_pattern = np.array(modisco.affinitymat.jaccard(X, Y).squeeze())

    return max(sim_fwd_pattern[0], sim_rev_pattern[0])

def read_html_to_dataframe(source: str) -> pd.DataFrame | str:
    """
    Read an HTML table from the Modisco report function into a DataFrame.

    Parameters:
    - source (str): The URL or file path to the HTML content.

    Returns:
    - DataFrame containing the first HTML table, or an error message if no table is found.
    """
    try:
        # Attempt to read the HTML content
        dfs = pd.read_html(source)

        # Check if any tables were found
        if not dfs:
            return "No tables found in the HTML content."

        # Return the first DataFrame
        return dfs[0]
    except ValueError as e:
        # Handle the case where no tables are found
        return f"Error: {str(e)}"
    except Exception as e:
        # Handle any other unexpected exceptions
        return f"An error occurred: {str(e)}"
