Skip to content

Commit

Permalink
make class names default argument in contrib scores calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasMahieu committed Jul 16, 2024
1 parent 876ac20 commit 7e26dc4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
52 changes: 34 additions & 18 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,8 @@ def predict_sequence(self, sequence: str) -> np.ndarray:

def calculate_contribution_scores(
self,
class_names: list[str],
anndata: AnnData | None = None,
class_names: list[str] | None = None,
method: str = "expected_integrated_grad",
store_in_varm: bool = False,
) -> tuple[np.ndarray, np.ndarray] | None:
Expand All @@ -628,12 +628,12 @@ def calculate_contribution_scores(
Parameters
----------
class_names
List of class names to calculate the contribution scores for (should match anndata.obs_names)
If the list is empty, the contribution scores for the 'combined' class will be calculated.
anndata
Anndata object to store the contribution scores in as a .varm[class_name] attribute.
If None, will only return the contribution scores without storing them.
class_names
List of class names to calculate the contribution scores for (should match anndata.obs_names)
If None, the contribution scores for the 'combined' class will be calculated.
method
Method to use for calculating the contribution scores.
Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'.
Expand All @@ -646,6 +646,8 @@ def calculate_contribution_scores(
--------
crested.pl.patterns.contribution_scores
"""
if isinstance(class_names, str):
class_names = [class_names]
self._check_contribution_scores_params(class_names)

if self.anndatamodule.predict_dataset is None:
Expand All @@ -657,13 +659,16 @@ def calculate_contribution_scores(

all_class_names = list(self.anndatamodule.adata.obs_names)

if class_names is not None:
if len(class_names) > 0:
n_classes = len(class_names)
class_indices = [
all_class_names.index(class_name) for class_name in class_names
]
varm_names = class_names
else:
logger.warning(
"No class names provided. Calculating contribution scores for the 'combined' class."
)
n_classes = 1 # 'combined' class
class_indices = [None]
varm_names = ["combined"]
Expand Down Expand Up @@ -727,7 +732,7 @@ def calculate_contribution_scores(
def calculate_contribution_scores_regions(
self,
region_idx: list[str] | str,
class_names: list[str] | None = None,
class_names: list[str],
method: str = "expected_integrated_grad",
disable_tqdm: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
Expand All @@ -743,7 +748,7 @@ def calculate_contribution_scores_regions(
Region(s) for which to calculate the contribution scores in the format "chr:start-end".
class_names
List of class names to calculate the contribution scores for (should match anndata.obs_names)
If None, the contribution scores for the 'combined' class will be calculated.
If the list is empty, the contribution scores for the 'combined' class will be calculated.
method
Method to use for calculating the contribution scores.
Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'.
Expand All @@ -761,6 +766,9 @@ def calculate_contribution_scores_regions(
if isinstance(region_idx, str):
region_idx = [region_idx]

if isinstance(class_names, str):
class_names = [class_names]

if self.anndatamodule.predict_dataset is None:
self.anndatamodule.setup("predict")

Expand All @@ -779,7 +787,7 @@ def calculate_contribution_scores_regions(
def calculate_contribution_scores_sequence(
self,
sequences: list[str] | str,
class_names: list[str] | None = None,
class_names: list[str],
method: str = "expected_integrated_grad",
disable_tqdm: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
Expand All @@ -795,7 +803,7 @@ def calculate_contribution_scores_sequence(
Sequence(s) for which to calculate the contribution scores.
class_names
List of class names to calculate the contribution scores for (should match anndata.obs_names)
If None, the contribution scores for the 'combined' class will be calculated.
If the list is empty, the contribution scores for the 'combined' class will be calculated.
method
Method to use for calculating the contribution scores.
Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'.
Expand Down Expand Up @@ -826,12 +834,15 @@ def calculate_contribution_scores_sequence(

all_class_names = list(self.anndatamodule.adata.obs_names)

if class_names is not None:
if len(class_names) > 0:
n_classes = len(class_names)
class_indices = [
all_class_names.index(class_name) for class_name in class_names
]
else:
logger.warning(
"No class names provided. Calculating contribution scores for the 'combined' class."
)
n_classes = 1 # 'combined' class
class_indices = [None]

Expand Down Expand Up @@ -1349,19 +1360,24 @@ def _check_predict_params(self, anndata: AnnData | None, model_name: str | None)
)

@log_and_raise(ValueError)
def _check_contribution_scores_params(self, class_names: list | None):
def _check_contribution_scores_params(self, class_names: list):
"""Check if the necessary parameters are set for the calculate_contribution_scores method."""
if not self.model:
raise ValueError(
"Model not set. Please load a model from pretrained using Crested.load_model(...) before calling calculate_contribution_scores_(regions)."
)
if class_names is not None:
all_class_names = list(self.anndatamodule.adata.obs_names)
for class_name in class_names:
if class_name not in all_class_names:
raise ValueError(
f"Class name {class_name} not found in anndata.obs_names."
)
# check if class names is a list
if not isinstance(class_names, list):
raise ValueError(
"Class names should be a list of class names or an empty list (if calculating the average accross classes)."
)

all_class_names = list(self.anndatamodule.adata.obs_names)
for class_name in class_names:
if class_name not in all_class_names:
raise ValueError(
f"Class name {class_name} not found in anndata.obs_names."
)

def __repr__(self):
return f"Crested(data={self.anndatamodule is not None}, model={self.model is not None}, config={self.config is not None})"
6 changes: 6 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,15 @@ def test_peak_regression():
scores, one_hot_encoded_sequences = trainer.calculate_contribution_scores_sequence(
sequence, class_names=["cell_1", "cell_2"], method="integrated_grad"
)

assert scores.shape == (1, 2, 600, 4)
assert one_hot_encoded_sequences.shape == (1, 600, 4)

scores, one_hot_encoded_sequences = trainer.calculate_contribution_scores_regions(
region_idx=["chr1:1000-1600", "chr2:2000-2600"],
class_names=[],
method="integrated_grad",
)
trainer.enhancer_design_in_silico_evolution(
target_class="cell_1", n_sequences=1, n_mutations=1
)

0 comments on commit 7e26dc4

Please sign in to comment.