Skip to content

Commit

Permalink
Merge pull request #17 from aertslab/return_embeddings
Browse files Browse the repository at this point in the history
get_embeddings based on layer name
  • Loading branch information
LukasMahieu authored Sep 17, 2024
2 parents 84a205c + 4d87e78 commit 9f737da
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,44 @@ def test(self, return_metrics: bool = False) -> dict | None:
if return_metrics:
return evaluation_metrics

def get_embeddings(
    self,
    layer_name: str = "global_average_pooling1d",
    anndata: AnnData | None = None,
) -> np.ndarray:
    """
    Extract the output of a named layer for every region in the predict dataset.

    A sub-model is built that maps the inputs of ``self.model`` to the output
    of the layer called ``layer_name``; that sub-model is then run over the
    predict dataloader of ``self.anndatamodule``. If ``anndata`` is provided,
    the result is additionally stored in ``anndata.obsm[layer_name]``.

    Parameters
    ----------
    layer_name
        The name of the layer from which to extract the embeddings.
    anndata
        Optional AnnData object that receives the embeddings in
        ``anndata.obsm[layer_name]``. It is only written to, never read, so
        its observations are assumed to line up with the predict dataset of
        ``self.anndatamodule`` (NOTE(review): confirm against callers).

    Returns
    -------
    Embeddings as returned by the sub-model's ``predict``. For a pooling or
    dense layer this is of shape (N, D), where D is the size of the embedding
    layer; other layers keep whatever output shape they produce.

    Raises
    ------
    ValueError
        If no layer named ``layer_name`` exists in the model.
    """
    if not any(layer.name == layer_name for layer in self.model.layers):
        raise ValueError(f"Layer '{layer_name}' not found in model.")

    # Truncated copy of the model: same inputs, output taken at the requested layer.
    embedding_model = keras.models.Model(
        inputs=self.model.input,
        outputs=self.model.get_layer(layer_name).output,
    )

    # Lazily prepare the predict dataset the first time it is needed.
    if self.anndatamodule.predict_dataset is None:
        self.anndatamodule.setup("predict")
    predict_loader = self.anndatamodule.predict_dataloader

    # Ask Keras which backend is active instead of reading the KERAS_BACKEND
    # environment variable directly: the variable may legitimately be unset
    # (Keras then falls back to its configured default), in which case
    # os.environ["KERAS_BACKEND"] would raise KeyError.
    # An explicit step count is only passed on the TensorFlow backend; other
    # backends let ``predict`` run until the loader is exhausted.
    if keras.backend.backend() == "tensorflow":
        n_predict_steps = len(predict_loader)
    else:
        n_predict_steps = None

    embeddings = embedding_model.predict(predict_loader.data, steps=n_predict_steps)

    if anndata is not None:
        anndata.obsm[layer_name] = embeddings
    return embeddings

def predict(
self,
anndata: AnnData | None = None,
Expand Down

0 comments on commit 9f737da

Please sign in to comment.