
Commit

Merge pull request #15 from aertslab/custom_loss_function_in_silico_evolution

custom loss function enhancer design
SeppeDeWinter authored Sep 27, 2024
2 parents 982e25b + 095d753 commit 81ba818
Showing 2 changed files with 196 additions and 63 deletions.
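At a glance: the hard-coded `_weighted_difference` scoring step in both enhancer-design methods is replaced by a pluggable `EnhancerOptimizer` wrapper, and extra keyword arguments are now forwarded to its `get_best` method. A minimal sketch of the new flow, using names from the diff below (the toy arrays and the `best_idx` variable are illustrative, not part of the commit):

import numpy as np

from crested.tl._utils import EnhancerOptimizer, _weighted_difference

# Default behaviour is unchanged: wrap the built-in scoring function.
optimizer = EnhancerOptimizer(optimize_func=_weighted_difference)

# Toy predictions: 3 candidate mutations x 4 classes; the original
# prediction is kept separate.
mutated_predictions = np.array(
    [[0.10, 0.20, 0.30, 0.40],
     [0.60, 0.10, 0.10, 0.20],
     [0.20, 0.20, 0.20, 0.40]]
)
original_prediction = np.array([[0.25, 0.25, 0.25, 0.25]])

# Extra keyword arguments (e.g. the former class_penalty_weights
# parameter) are forwarded through **kwargs to the wrapped scoring function.
best_idx = optimizer.get_best(
    mutated_predictions=mutated_predictions,
    original_prediction=original_prediction,
    target=0,
    class_penalty_weights=np.ones(4),
)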
230 changes: 168 additions & 62 deletions src/crested/tl/_crested.py
@@ -10,12 +10,15 @@
 from anndata import AnnData
 from loguru import logger
 from tqdm import tqdm
+from typing import Callable, Any
+from pysam import FastaFile
 
 
 from crested._logging import log_and_raise
 from crested.tl import TaskConfig
 from crested.tl._utils import (
     _weighted_difference,
+    EnhancerOptimizer,
     generate_motif_insertions,
     generate_mutagenesis,
     hot_encoding_to_sequence,
@@ -1139,27 +1142,33 @@ def tfmodisco_calculate_and_save_contribution_scores(

     def enhancer_design_motif_implementation(
         self,
-        target_class: str,
         n_sequences: int,
         patterns: dict,
+        target_class: str | None = None,
+        target: int | np.ndarray | None = None,
         insertions_per_pattern: dict | None = None,
         return_intermediate: bool = False,
-        class_penalty_weights: np.ndarray | None = None,
         no_mutation_flanks: tuple | None = None,
         target_len: int | None = None,
         preserve_inserted_motifs: bool = True,
-    ) -> tuple[list(dict), list] | list:
+        enhancer_optimizer: EnhancerOptimizer | None = None,
+        **kwargs: dict[str, Any]
+    ) -> tuple[list[dict], list] | list:
         """
         Create synthetic enhancers for a specified class using motif implementation.
 
         Parameters
         ----------
-        target_class
-            Class name for which the enhancers will be designed for.
         n_sequences
             Number of enhancers to design.
         patterns
             Dictionary of patterns to be implemented, in the form 'pattern_name': 'pattern_sequence'.
+        target_class
+            Class name for which the enhancers will be designed. If None, `target` must be
+            specified instead.
+        target
+            Target index; must be specified when `target_class` is None.
         insertions_per_pattern
             Dictionary of the number of insertions per pattern, in the form 'pattern_name': number_of_insertions.
             If not supplied, one insertion of each pattern in `patterns` is made.
@@ -1176,17 +1185,33 @@ def enhancer_design_motif_implementation(
             is supplied.
         preserve_inserted_motifs
             If True, sequentially inserted motifs cannot be inserted on top of previously inserted motifs.
+        enhancer_optimizer
+            An instance of EnhancerOptimizer, defining how sequences should be optimized.
+            If None, a default EnhancerOptimizer is initialized using `_weighted_difference`
+            as the optimization function.
+        kwargs
+            Keyword arguments passed through to the `get_best` function of the EnhancerOptimizer.
 
         Returns
         -------
         A list of designed sequences and, if return_intermediate is True, a list of dictionaries of intermediate
         mutations and predictions.
         """
-        self._check_contribution_scores_params([target_class])
+        if target_class is not None:
+            self._check_contribution_scores_params([target_class])
 
-        all_class_names = list(self.anndatamodule.adata.obs_names)
+            all_class_names = list(self.anndatamodule.adata.obs_names)
 
-        target = all_class_names.index(target_class)
+            target = all_class_names.index(target_class)
+        elif target is None:
+            raise ValueError("`target` needs to be specified when `target_class` is None")
 
+        if enhancer_optimizer is None:
+            enhancer_optimizer = EnhancerOptimizer(
+                optimize_func=_weighted_difference
+            )
 
         # get input sequence length of the model
         seq_len = (
@@ -1264,11 +1289,11 @@ def enhancer_design_motif_implementation(
             mutagenesis_predictions = self.model.predict(mutagenesis)
 
             # determine the best insertion site
-            best_mutation = _weighted_difference(
-                mutagenesis_predictions,
-                current_prediction,
-                target,
-                class_penalty_weights,
+            best_mutation = enhancer_optimizer.get_best(
+                mutated_predictions=mutagenesis_predictions,
+                original_prediction=current_prediction,
+                target=target,
+                **kwargs
             )
 
             sequence_onehot = mutagenesis[best_mutation : best_mutation + 1]
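For callers, the visible change is that `class_penalty_weights` is no longer a named parameter of the design methods; it now travels through `**kwargs` into `get_best` and on to the scoring function. A hedged usage sketch (the `crested_obj`, class name, and motif are hypothetical placeholders, not from the commit):

import numpy as np

patterns = {"motif_A": "AACAAT"}  # 'pattern_name': 'pattern_sequence'
weights = np.ones(10)             # hypothetical: one penalty weight per model class

designed = crested_obj.enhancer_design_motif_implementation(
    n_sequences=2,
    patterns=patterns,
    target_class="my_cell_type",
    # forwarded via **kwargs -> EnhancerOptimizer.get_best -> _weighted_difference
    class_penalty_weights=weights,
)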
@@ -1305,47 +1330,67 @@ def enhancer_design_motif_implementation(

     def enhancer_design_in_silico_evolution(
         self,
-        target_class: str,
         n_mutations: int,
         n_sequences: int,
+        target_class: str | None = None,
+        target: int | np.ndarray | None = None,
         return_intermediate: bool = False,
-        class_penalty_weights: np.ndarray | None = None,
         no_mutation_flanks: tuple | None = None,
         target_len: int | None = None,
-    ) -> tuple[list(dict), list] | list:
+        enhancer_optimizer: EnhancerOptimizer | None = None,
+        **kwargs: dict[str, Any]
+    ) -> tuple[list[dict], list] | list:
         """
         Create synthetic enhancers for a specified class using in silico evolution (ISE).
 
         Parameters
         ----------
-        target_class
-            Class name for which the enhancers will be designed for.
         n_mutations
-            Number of mutations per sequence
+            Number of iterations
         n_sequences
            Number of enhancers to design
+        target_class
+            Class name for which the enhancers will be designed. If None, `target` must be
+            specified instead.
+        target
+            Target index; must be specified when `target_class` is None.
         return_intermediate
             If True, returns a dictionary with predictions and changes made in intermediate steps for selected
             sequences
-        class_penalty_weights
-            Array with a value per class, determining the penalty weight for that class to be used in scoring
-            function for sequence selection.
         no_mutation_flanks
             A tuple of integers determining the flank regions in which no mutations are made.
         target_len
             Length of the central region of the sequence in which to make mutations; ignored if no_mutation_flanks
             is supplied.
+        enhancer_optimizer
+            An instance of EnhancerOptimizer, defining how sequences should be optimized.
+            If None, a default EnhancerOptimizer is initialized using `_weighted_difference`
+            as the optimization function.
+        kwargs
+            Keyword arguments passed through to the `get_best` function of the EnhancerOptimizer.
 
         Returns
         -------
         A list of designed sequences and, if return_intermediate is True, a list of dictionaries of intermediate
         mutations and predictions.
         """
+        if self.model is None:
+            raise ValueError("Model should be loaded first!")
 
-        self._check_contribution_scores_params([target_class])
+        if target_class is not None:
+            self._check_contribution_scores_params([target_class])
 
-        all_class_names = list(self.anndatamodule.adata.obs_names)
+            all_class_names = list(self.anndatamodule.adata.obs_names)
 
-        target = all_class_names.index(target_class)
+            target = all_class_names.index(target_class)
+        elif target is None:
+            raise ValueError("`target` needs to be specified when `target_class` is None")
 
+        if enhancer_optimizer is None:
+            enhancer_optimizer = EnhancerOptimizer(
+                optimize_func=_weighted_difference
+            )
 
         # get input sequence length of the model
         seq_len = (
@@ -1378,58 +1423,119 @@ def enhancer_design_in_silico_evolution(
             n_sequences=n_sequences, seq_len=seq_len
         )
 
-        designed_sequences = []
-        intermediate_info_list = []
+        # initialize
+        designed_sequences: list[str] = []
+        intermediate_info_list: list[dict] = []
+
+        sequence_onehot_prev_iter = np.zeros(
+            (n_sequences, seq_len, 4),
+            dtype=np.uint8
+        )
+
+        # calculate total number of mutations per sequence
+        _, L, A = sequence_onehot_prev_iter.shape
+        start, end = 0, L
+        start = no_mutation_flanks[0]
+        end = L - no_mutation_flanks[1]
+        TOTAL_NUMBER_OF_MUTATIONS_PER_SEQ = (end - start) * (A - 1)
+
+        mutagenesis = np.zeros(
+            (n_sequences, TOTAL_NUMBER_OF_MUTATIONS_PER_SEQ, seq_len, 4)
+        )
 
-        for idx, sequence in enumerate(random_sequences):
-            sequence_onehot = one_hot_encode_sequence(sequence)
-            if return_intermediate:
-                intermediate_info_list.append(
-                    {
-                        "inital_sequence": sequence,
-                        "changes": [(-1, "N")],
-                        "predictions": [
-                            self.model.predict(sequence_onehot, verbose=False)
-                        ],
-                        "designed_sequence": "",
-                    }
-                )
+        for i, sequence in enumerate(random_sequences):
+            sequence_onehot_prev_iter[i] = one_hot_encode_sequence(sequence)
 
-            # sequentially do mutations
-            for _mutation_step in range(n_mutations):
-                current_prediction = self.model.predict(sequence_onehot, verbose=False)
+        for _iter in tqdm(range(n_mutations)):
+            baseline_prediction = self.model.predict(
+                sequence_onehot_prev_iter,
+                verbose=False
+            )
+
+            if _iter == 0:
+                for i in range(n_sequences):
+                    # initialize info
+                    intermediate_info_list.append(
+                        {
+                            "inital_sequence": hot_encoding_to_sequence(
+                                sequence_onehot_prev_iter[i]
+                            ),
+                            "changes": [(-1, "N")],
+                            "predictions": [
+                                baseline_prediction[i]
+                            ],
+                            "designed_sequence": "",
+                        }
+                    )
 
-                # do every possible mutation
-                mutagenesis = generate_mutagenesis(
-                    sequence_onehot, include_original=False, flanks=no_mutation_flanks
-                )
-                mutagenesis_predictions = self.model.predict(mutagenesis)
+            # do all possible mutations
+            for i in range(n_sequences):
+                mutagenesis[i] = generate_mutagenesis(
+                    sequence_onehot_prev_iter[i : i + 1],
+                    include_original=False, flanks=no_mutation_flanks
+                )
 
-                # determine the best mutation
-                best_mutation = _weighted_difference(
-                    mutagenesis_predictions,
-                    current_prediction,
-                    target,
-                    class_penalty_weights,
-                )
+            mutagenesis_predictions = self.model.predict(
+                mutagenesis.reshape(
+                    (n_sequences * TOTAL_NUMBER_OF_MUTATIONS_PER_SEQ, seq_len, 4)
+                )
+            )
 
-                sequence_onehot = mutagenesis[best_mutation : best_mutation + 1]
+            mutagenesis_predictions = mutagenesis_predictions.reshape(
+                (
+                    n_sequences,
+                    TOTAL_NUMBER_OF_MUTATIONS_PER_SEQ,
+                    mutagenesis_predictions.shape[1]
+                )
+            )
+
+            for i in range(n_sequences):
+                best_mutation = enhancer_optimizer.get_best(
+                    mutated_predictions=mutagenesis_predictions[i],
+                    original_prediction=baseline_prediction[i],
+                    target=target,
+                    **kwargs
+                )
+                sequence_onehot_prev_iter[i] = mutagenesis[
+                    i,
+                    best_mutation : best_mutation + 1,
+                    :
+                ]
                 if return_intermediate:
                     mutation_index = best_mutation // 3 + no_mutation_flanks[0]
-                    changed_to = sequence_onehot[0, mutation_index, :]
-                    intermediate_info_list[idx]["changes"].append(
-                        (mutation_index, hot_encoding_to_sequence(changed_to))
-                    )
-                    intermediate_info_list[idx]["predictions"].append(
-                        mutagenesis_predictions[best_mutation]
-                    )
+                    changed_to = hot_encoding_to_sequence(
+                        sequence_onehot_prev_iter[i, mutation_index, :]
+                    )
+                    intermediate_info_list[i]["changes"].append(
+                        (mutation_index, changed_to)
+                    )
+                    intermediate_info_list[i]["predictions"].append(
+                        mutagenesis_predictions[i][best_mutation]
+                    )
 
-            designed_sequence = hot_encoding_to_sequence(sequence_onehot)
-            designed_sequences.append(designed_sequence)
+        # get final sequence
+        for i in range(n_sequences):
+            best_mutation = enhancer_optimizer.get_best(
+                mutated_predictions=mutagenesis_predictions[i],
+                original_prediction=baseline_prediction[i],
+                target=target,
+                **kwargs
+            )
+
+            designed_sequence = hot_encoding_to_sequence(
+                mutagenesis[
+                    i,
+                    best_mutation : best_mutation + 1,
+                    :
+                ]
+            )
+
+            designed_sequences.append(
+                designed_sequence
+            )
 
             if return_intermediate:
-                intermediate_info_list[idx]["designed_sequence"] = designed_sequence
+                intermediate_info_list[i]["designed_sequence"] = designed_sequence
 
         if return_intermediate:
             return intermediate_info_list, designed_sequences
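Two things changed in `enhancer_design_in_silico_evolution`: the loop is now batched, so each iteration scores all `(end - start) * (A - 1)` point mutations of all sequences in a single `predict` call (for a 500 bp editable window and A = 4 nucleotides, that is 1500 candidates per sequence), and the selection step is pluggable. A hedged sketch of plugging in a custom scoring function (the function and `crested_obj` are illustrative; the only contract, per `get_best` below, is the positional signature and an integer index return):

import numpy as np

from crested.tl._utils import EnhancerOptimizer


def maximize_target_only(
    mutated_predictions: np.ndarray,
    original_prediction: np.ndarray,
    target: int,
    **kwargs,
) -> np.intp:
    # Ignore off-target classes: pick the mutation with the highest
    # predicted score for the target class alone.
    return np.argmax(mutated_predictions[:, target])


designed = crested_obj.enhancer_design_in_silico_evolution(
    n_mutations=15,
    n_sequences=4,
    target_class="my_cell_type",
    enhancer_optimizer=EnhancerOptimizer(optimize_func=maximize_target_only),
)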
29 changes: 28 additions & 1 deletion src/crested/tl/_utils.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing import Any, Callable
+
 import numpy as np
 import pyBigWig

@@ -93,10 +95,35 @@ def generate_motif_insertions(x, motif, flanks=(0, 0), masked_locations=None):

     return np.concatenate(x_mut, axis=0), insertion_locations
 
+class EnhancerOptimizer:
+    def __init__(
+        self,
+        optimize_func: Callable[..., np.intp]
+    ) -> None:
+        self.optimize_func = optimize_func
+
+    def get_best(
+        self,
+        mutated_predictions: np.ndarray,
+        original_prediction: np.ndarray,
+        target: int | np.ndarray,
+        **kwargs: dict[str, Any]
+    ) -> np.intp:
+        return self.optimize_func(
+            mutated_predictions,
+            original_prediction,
+            target,
+            **kwargs
+        )
+
 def _weighted_difference(
-    mutated_predictions, original_prediction, target, class_penalty_weights=None
+    mutated_predictions: np.ndarray,
+    original_prediction: np.ndarray,
+    target: int,
+    class_penalty_weights: np.ndarray | None = None
 ):
+    if len(original_prediction.shape) == 1:
+        original_prediction = original_prediction[None]
     n_classes = original_prediction.shape[1]
     penalty_factor = 1 / n_classes
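The diff cuts off the body of `_weighted_difference` after `penalty_factor = 1 / n_classes`. As a rough guide to what such a scorer computes, and emphatically not the committed implementation, here is a plausible reconstruction with the same signature: reward the prediction gain on the target class, penalize (optionally weighted) gains on the other classes, and return the index of the best mutation:

from __future__ import annotations

import numpy as np


def weighted_difference_sketch(
    mutated_predictions: np.ndarray,
    original_prediction: np.ndarray,
    target: int,
    class_penalty_weights: np.ndarray | None = None,
) -> np.intp:
    # Plausible reconstruction only; the committed body is truncated above.
    if len(original_prediction.shape) == 1:
        original_prediction = original_prediction[None]
    n_classes = original_prediction.shape[1]
    penalty_factor = 1 / n_classes
    if class_penalty_weights is None:
        class_penalty_weights = np.ones(n_classes)

    # Per-mutation change in prediction for every class.
    diff = mutated_predictions - original_prediction
    # Gain on the target class, minus a weighted penalty for off-target changes.
    off_target = np.sum(np.abs(diff) * class_penalty_weights, axis=1) - np.abs(diff[:, target])
    score = diff[:, target] - penalty_factor * off_target
    return np.argmax(score)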
