Skip to content

Commit

Permalink
pattern clustering update p2
Browse files Browse the repository at this point in the history
  • Loading branch information
nkempynck committed Nov 19, 2024
1 parent 9403a50 commit 9597cc5
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 75 deletions.
130 changes: 83 additions & 47 deletions docs/tutorials/enhancer_code_analysis.ipynb

Large diffs are not rendered by default.

19 changes: 11 additions & 8 deletions src/crested/pl/patterns/_modisco_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def clustermap(
fig_path: str | None = None,
pat_seqs: list[tuple[str, np.ndarray]] | None = None,
dendrogram_ratio: tuple[float, float] = (0.05, 0.2),
importance_threshold : float = 0,
) -> sns.matrix.ClusterGrid:
"""
Create a clustermap from the given pattern matrix and class labels with customizable options.
Expand Down Expand Up @@ -235,6 +236,8 @@ def clustermap(
List of sequences to use as xticklabels.
dendrogram_ratio
Ratio of dendrograms in x and y directions.
importance_threshold
Minimum pattern importance threshold. A pattern is retained for clustering and plotting only if its maximum absolute importance over all classes exceeds this value.
See Also
--------
Expand Down Expand Up @@ -262,13 +265,13 @@ def clustermap(
pattern_matrix = pattern_matrix[subset_indices, :]
classes = [classes[i] for i in subset_indices]

# Remove columns that contain only zero values
non_zero_columns = np.any(pattern_matrix != 0, axis=0)
pattern_matrix = pattern_matrix[:, non_zero_columns]
# Filter columns based on importance threshold
max_importance = np.max(np.abs(pattern_matrix), axis=0)
above_threshold = max_importance > importance_threshold
pattern_matrix = pattern_matrix[:, above_threshold]

# Reindex columns based on the original positions of non-zero columns
column_indices = np.where(non_zero_columns)[0]
data = pd.DataFrame(pattern_matrix, columns=column_indices)
if pat_seqs is not None:
pat_seqs = [pat_seqs[i] for i in np.where(above_threshold)[0]]

data = pd.DataFrame(pattern_matrix)

Expand Down Expand Up @@ -301,7 +304,7 @@ def clustermap(

# Reorder the pat_seqs to follow the column order
if pat_seqs is not None:
reordered_pat_seqs = [pat_seqs[column_indices[i]] for i in col_order]
reordered_pat_seqs = [pat_seqs[i] for i in col_order]
ax = g.ax_heatmap
x_positions = (
np.arange(len(reordered_pat_seqs)) + 0.5
Expand Down Expand Up @@ -663,7 +666,7 @@ def clustermap_tf_motif(
ax_scatter = fig.add_subplot(111)

# Define color normalization
norm = mcolors.TwoSlopeNorm(vmin=color_data.min(), vcenter=0, vmax=color_data.max())
norm = mcolors.TwoSlopeNorm(vmin=-max(np.abs(color_data.min()), np.abs(color_data.max())), vcenter=0, vmax=max(np.abs(color_data.min()), np.abs(color_data.max())))

# Plot scatter matrix
sc = ax_scatter.scatter(
Expand Down
1 change: 1 addition & 0 deletions src/crested/tl/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from ._cosinemse import CosineMSELoss
from ._cosinemse_log import CosineMSELogLoss
from ._poisson import PoissonLoss
from ._poissonmultinomial import PoissonMultinomialLoss
3 changes: 2 additions & 1 deletion src/crested/tl/losses/_poissonmultinomial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import keras
import keras.ops as ops


@keras.saving.register_keras_serializable(package="Losses")
class PoissonMultinomialLoss(keras.losses.Loss):
"""
Expand Down Expand Up @@ -99,4 +100,4 @@ def get_config(self):
"log_input": self.log_input,
"axis": self.axis,
})
return config
return config
88 changes: 69 additions & 19 deletions src/crested/tl/modisco/_tfmodisco.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def match_to_patterns(
"""
p["id"] = pattern_id
p["pos_pattern"] = pos_pattern
p['n_seqlets'] = p['seqlets']['n_seqlets'][0]
if not all_patterns:
return add_pattern_to_dict(p, 0, cell_type, pos_pattern, all_patterns)

Expand Down Expand Up @@ -297,10 +298,10 @@ def match_to_patterns(
all_patterns[str(match_idx)]["instances"][pattern_id] = p

if cell_type in all_patterns[str(match_idx)]["classes"].keys():
ic_class_representative = all_patterns[str(match_idx)]["classes"][cell_type][
"ic"
]
ic_class_representative = all_patterns[str(match_idx)]["classes"][cell_type]["ic"]
n_seqlets_class_representative = all_patterns[str(match_idx)]["classes"][cell_type]['n_seqlets']
if p_ic > ic_class_representative:
p['n_seqlets'] = n_seqlets_class_representative if n_seqlets_class_representative > p['n_seqlets'] else p['n_seqlets'] # if a class representative for a pattern gets replaced, we keep the max seqlet count between the two of them since they are the same pattern
all_patterns[str(match_idx)]["classes"][cell_type] = p
else:
all_patterns[str(match_idx)]["classes"][cell_type] = p
Expand Down Expand Up @@ -452,20 +453,22 @@ def merge_patterns(pattern1: dict, pattern2: dict) -> dict:
for cell_type in pattern1["classes"].keys():
if cell_type in pattern2["classes"].keys():
ic_a = pattern1["classes"][cell_type]["ic"]
n_seqlets_a = pattern1["classes"][cell_type]['n_seqlets']
ic_b = pattern2["classes"][cell_type]["ic"]
n_seqlets_b = pattern2["classes"][cell_type]['n_seqlets']
merged_classes[cell_type] = (
pattern1["classes"][cell_type]
if ic_a > ic_b
else pattern2["classes"][cell_type]
)
merged_classes[cell_type]['n_seqlets'] = max(n_seqlets_a, n_seqlets_b) # if patterns from the same class get merged, we keep the max seqlet count between the two of them since they are the same pattern
else:
merged_classes[cell_type] = pattern1["classes"][cell_type]

for cell_type in pattern2["classes"].keys():
if cell_type not in merged_classes.keys():
merged_classes[cell_type] = pattern2["classes"][cell_type]

merged_classes = {**pattern1["classes"], **pattern2["classes"]}
merged_instances = {**pattern1["instances"], **pattern2["instances"]}

if pattern2["ic"] > pattern1["ic"]:
Expand Down Expand Up @@ -700,6 +703,7 @@ def create_pattern_matrix(
classes: list[str],
all_patterns: dict[str, dict[str, str | list[float]]],
normalize: bool = False,
pattern_parameter : str = 'seqlet_count'
) -> np.ndarray:
"""
Create a pattern matrix from classes and patterns, with optional normalization.
Expand All @@ -712,6 +716,8 @@ def create_pattern_matrix(
dictionary containing pattern data.
normalize
Flag to indicate whether to normalize the rows of the matrix.
pattern_parameter
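Quantity used as the pattern's importance: the average contribution score ('contrib'), the number of pattern instances ('seqlet_count', default), or the log1p of that count ('seqlet_count_log'). Seqlet counts take the sign of the average contribution score, so negative patterns get a negative count.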
See Also
--------
Expand All @@ -722,15 +728,34 @@ def create_pattern_matrix(
-------
The resulting pattern matrix, optionally normalized.
"""
if pattern_parameter not in ['contrib', 'seqlet_count', 'seqlet_count_log']:
logger.info(
"Pattern parameter not valid. Setting to default ('seqlet_count')"
)
pattern_parameter = 'seqlet_count'

pattern_matrix = np.zeros((len(classes), len(all_patterns.keys())))

for p_idx in all_patterns:
p_classes = list(all_patterns[p_idx]["classes"].keys())
for ct in p_classes:
idx = np.argwhere(np.array(classes) == ct)[0][0]
pattern_matrix[idx, int(p_idx)] = np.mean(
all_patterns[p_idx]["classes"][ct]["contrib_scores"]
)
avg_contrib = np.mean(
all_patterns[p_idx]["classes"][ct]["contrib_scores"]
)
if pattern_parameter == 'contrib':
pattern_matrix[idx, int(p_idx)] = avg_contrib
elif pattern_parameter == 'seqlet_count':
sign = 1 if avg_contrib > 0 else -1 # Negative patterns will have a 'negative' count to reflect the negative performance.
pattern_matrix[idx, int(p_idx)] = sign * all_patterns[p_idx]["classes"][ct]["n_seqlets"]
elif pattern_parameter == 'seqlet_count_log':
sign = 1 if avg_contrib > 0 else -1 # Negative patterns will have a 'negative' count to reflect the negative performance.
pattern_matrix[idx, int(p_idx)] = sign * np.log1p(all_patterns[p_idx]["classes"][ct]["n_seqlets"])
else:
raise ValueError(
"Invalid pattern_parameter. Set to either 'contrib' or 'seqlet_count'."
)


# Filter out columns that are all zeros
filtered_array = pattern_matrix[:, ~np.all(pattern_matrix == 0, axis=0)]
Expand Down Expand Up @@ -1036,8 +1061,11 @@ def create_tf_ct_matrix(
df: pd.DataFrame,
classes: list[str],
log_transform: bool = True,
normalize: bool = True,
normalize_pattern_importances: bool = False,
normalize_gex: bool = False,
min_tf_gex: float = 0,
importance_threshold: float = 0,
pattern_parameter : str = "seqlet_count",
) -> tuple[np.ndarray, list[str]]:
"""
Create a tensor (matrix) of transcription factor (TF) expression and cell type contributions.
Expand All @@ -1054,10 +1082,16 @@ def create_tf_ct_matrix(
A list of cell type classes.
log_transform
Whether to apply log transformation to the gene expression values. Default is True.
normalize
Whether to normalize the contribution scores across the cell types. Default is True.
normalize_pattern_importances
Whether to normalize the pattern importances (see pattern_parameter) across the cell types. Default is False.
normalize_gex
Whether to normalize gene expression across the cell types. Default is False.
min_tf_gex
The minimal GEX value to select potential TF candidates. Default 0.
importance_threshold
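The minimum pattern importance value. A TF-pattern column is kept only if at least one cell type has an importance above this value together with TF expression above min_tf_gex. Default is 0.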
pattern_parameter
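Quantity used as the pattern's importance: the average contribution score ('contrib'), the number of pattern instances ('seqlet_count', default), or the log1p of that count ('seqlet_count_log'). Seqlet counts take the sign of the average contribution score, so negative patterns get a negative count.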
See Also
--------
Expand All @@ -1071,13 +1105,30 @@ def create_tf_ct_matrix(
tf_ct_matrix = np.zeros((len(classes), total_tf_patterns, 2))
tf_pattern_annots = []

if pattern_parameter not in ['contrib', 'seqlet_count', 'seqlet_count_log']:
logger.info(
"Pattern parameter not valid. Setting to default ('seqlet_count')."
)
pattern_parameter = 'seqlet_count'

counter = 0
for p_idx in pattern_tf_dict:
ct_contribs = np.zeros(len(classes))
for ct in all_patterns[p_idx]["classes"]:
idx = np.argwhere(np.array(classes) == ct)[0][0]
contribs = np.mean(all_patterns[p_idx]["classes"][ct]["contrib_scores"])
ct_contribs[idx] = contribs
if pattern_parameter == 'contrib':
ct_contribs[idx] = contribs
elif pattern_parameter =='seqlet_count':
sign = 1 if contribs > 0 else -1
ct_contribs[idx] = sign * all_patterns[p_idx]["classes"][ct]["n_seqlets"]
elif pattern_parameter =='seqlet_count_log':
sign = 1 if contribs > 0 else -1
ct_contribs[idx] = sign * np.log1p(all_patterns[p_idx]["classes"][ct]["n_seqlets"])
else:
raise ValueError(
"Invalid pattern_parameter. Set to either 'contrib' or 'seqlet_count'."
)

for tf in pattern_tf_dict[p_idx]["tfs"]:
if tf in df.columns:
Expand All @@ -1093,8 +1144,10 @@ def create_tf_ct_matrix(
tf_pattern_annots.append(tf_pattern_annot)

tf_ct_matrix = tf_ct_matrix[:, : len(tf_pattern_annots), :]
if normalize:
if normalize_pattern_importances:
tf_ct_matrix[:, :, 1] = normalize_rows(tf_ct_matrix[:, :, 1])
if normalize_gex:
tf_ct_matrix[:, :, 0] = normalize_rows(tf_ct_matrix[:, :, 0].T).T

# Logic to remove columns where tf_gex is zero for all non-zero ct_contribs
initial_columns = tf_ct_matrix.shape[1]
Expand All @@ -1104,13 +1157,11 @@ def create_tf_ct_matrix(
tf_gex_col = tf_ct_matrix[:, col, 0]
ct_contribs_col = tf_ct_matrix[:, col, 1]

# Identify non-zero ct_contribs
non_zero_contribs = ct_contribs_col != 0
# Identify relevant ct_contribs
relevant_contribs = ct_contribs_col > importance_threshold

# Check if all non-zero ct_contribs have zero tf_gex values
if np.any(non_zero_contribs) and np.any(
tf_gex_col[non_zero_contribs] > min_tf_gex
):
# Check if there are valid ct_contribs and tf_gex above the threshold
if np.any(relevant_contribs) and np.any(tf_gex_col[relevant_contribs] > min_tf_gex):
columns_to_keep.append(col)

# Convert columns_to_keep to a boolean mask
Expand All @@ -1132,7 +1183,6 @@ def create_tf_ct_matrix(

return tf_ct_matrix, tf_pattern_annots


def calculate_mean_expression_per_cell_type(
file_path: str, cell_type_column: str
) -> pd.DataFrame:
Expand Down
1 change: 1 addition & 0 deletions src/crested/tl/zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ._basenji import basenji
from ._chrombpnet import chrombpnet
from ._chrombpnet_decoupled import chrombpnet_decoupled
from ._deeptopic_cnn import deeptopic_cnn
from ._deeptopic_lstm import deeptopic_lstm
from ._simple_convnet import simple_convnet

0 comments on commit 9597cc5

Please sign in to comment.