Skip to content

Commit

Permalink
added gene scoring and plotting + notebook bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nkempynck committed Sep 26, 2024
1 parent d31be31 commit 982e25b
Show file tree
Hide file tree
Showing 8 changed files with 384 additions and 102 deletions.
8 changes: 4 additions & 4 deletions docs/tutorials/enhancer_code_analysis.ipynb

Large diffs are not rendered by default.

242 changes: 146 additions & 96 deletions docs/tutorials/model_training_and_eval.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ dependencies = [
"pybigtools>=0.2.0",
"seaborn",
"pooch",
"scanpy"
"scanpy",
"pybigwig"
]

[project.optional-dependencies]
Expand Down
1 change: 1 addition & 0 deletions src/crested/pl/hist/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from ._distribution import distribution
from ._locus_scoring import locus_scoring
62 changes: 62 additions & 0 deletions src/crested/pl/hist/_locus_scoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Distribution plots."""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from anndata import AnnData
from loguru import logger

def locus_scoring(scores, coordinates, range, gene_start=None, gene_end=None, title='Predictions across Genomic Regions', bigwig_values=None, bigwig_midpoints=None, filename=None):
"""
Plot the predictions as a line chart over the entire genomic input and optionally indicate the gene locus.
Additionally, plot values from a bigWig file if provided.
Parameters:
scores (np.array): An array of prediction scores for each window.
coordinates (np.array): An array of tuples, each containing the chromosome name and the start and end positions of the sequence for each window.
model_class (int): The class index to plot from the prediction scores.
gene_start (int, optional): The start position of the gene locus to highlight on the plot.
gene_end (int, optional): The end position of the gene locus to highlight on the plot.
title (str): The title of the plot.
bigwig_values (np.array, optional): A numpy array of values extracted from a bigWig file for the same coordinates.
bigwig_midpoints (list, optional): A list of base pair positions corresponding to the bigwig_values.
"""

# Extract the midpoints of the coordinates for plotting
midpoints = [(int(start) + int(end)) // 2 for _, start, end in coordinates]

# Plotting predictions
plt.figure(figsize=(30, 10))

# Top plot: Model predictions
plt.subplot(2, 1, 1)
plt.plot(np.arange(range[0], range[1]), scores, marker='o', linestyle='-', color='b', label='Prediction Score')
if gene_start is not None and gene_end is not None:
plt.axvspan(gene_start, gene_end, color='red', alpha=0.3, label='Gene Locus')
plt.title(title)
plt.xlabel('Genomic Position')
plt.ylabel('Prediction Score')
plt.ylim(bottom=0)
plt.xticks(rotation=90)
plt.grid(True)
plt.legend()

# Bottom plot: bigWig values
if bigwig_values is not None and bigwig_midpoints is not None:
plt.subplot(2, 1, 2)
plt.plot(bigwig_midpoints, bigwig_values, linestyle='-', color='g', label='bigWig Values')
if gene_start is not None and gene_end is not None:
plt.axvspan(gene_start, gene_end, color='red', alpha=0.3, label='Gene Locus')
plt.xlabel('Genomic Position')
plt.ylabel('bigWig Values')
plt.xticks(rotation=90)
plt.ylim(bottom=0)
plt.grid(True)
plt.legend()

plt.tight_layout()
if filename:
plt.savefig(filename)
plt.show()
2 changes: 2 additions & 0 deletions src/crested/tl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import data, losses, metrics, zoo
from ._configs import TaskConfig, default_configs
from ._crested import Crested
from ._utils import extract_bigwig_values_per_bp


def _optional_function_warning(*args, **kwargs):
Expand Down Expand Up @@ -65,6 +66,7 @@ def _optional_function_warning(*args, **kwargs):
"TaskConfig",
"default_configs",
"Crested",
"extract_bigwig_values_per_bp",
]

if MODISCOLITE_AVAILABLE:
Expand Down
125 changes: 125 additions & 0 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from anndata import AnnData
from loguru import logger
from tqdm import tqdm
from pysam import FastaFile

from crested._logging import log_and_raise
from crested.tl import TaskConfig
Expand Down Expand Up @@ -657,6 +658,130 @@ def predict_sequence(self, sequence: str) -> np.ndarray:

return predictions

def score_gene_locus(
self,
chr_name: str,
gene_start: int,
gene_end: int,
class_name: str,
strand: str = '+',
upstream: int = 50000,
downstream: int = 10000,
window_size: int = 2114,
central_size: int = 1000,
step_size: int = 50,
):
"""
Score regions upstream and downstream of a gene locus using the model's prediction.
The model predicts a value for the central 1000bp of each window.
Parameters:
----------
chr_name : str
The chromosome name (e.g., 'chr12').
gene_start : int
The start position of the gene locus (TSS for + strand).
gene_end : int
The end position of the gene locus (TSS for - strand).
class_name : str
Output class name for prediction.
strand : str
'+' for positive strand, '-' for negative strand. Default '+'.
upstream : int
Distance upstream of the gene to score. Default 50 000.
downstream : int
Distance downstream of the gene to score. Default 10 000.
window_size : int
Size of the window to use for scoring. Default 2114.
central_size : int
Size of the central region that the model predicts for. Default 1000.
step_size : int
Distance between consecutive windows. Default 50.
Returns:
--------
scores : np.array
An array of prediction scores across the entire genomic range.
coordinates : np.array
An array of tuples, each containing the chromosome name and the start and end positions of the sequence for each window.
min_loc : int
Start position of the entire scored region.
max_loc : int
End position of the entire scored region.
tss_position : int
The transcription start site (TSS) position.
"""
# Adjust upstream and downstream based on the strand
if strand == '+':
start_position = gene_start - upstream
end_position = gene_end + downstream
tss_position = gene_start # TSS is at the gene_start for positive strand
elif strand == '-':
end_position = gene_end + upstream
start_position = gene_start - downstream
tss_position = gene_end # TSS is at the gene_end for negative strand
else:
raise ValueError("Strand must be '+' or '-'.")

total_length = abs(end_position - start_position)

# Ratio to normalize the score contributions
ratio = central_size / step_size

# Initialize an array to store the scores, filled with zeros
scores = np.zeros(total_length)

# List to store coordinates of each window
coordinates = []

# Get class index
all_class_names = list(self.anndatamodule.adata.obs_names)
idx = all_class_names.index(class_name)

genome = FastaFile(self.anndatamodule.genome_file)

# Generate all windows and one-hot encode the sequences in parallel
all_sequences = []
all_coordinates = []

for pos in range(start_position, end_position, step_size):
window_start = pos
window_end = pos + window_size

# Ensure the window stays within the bounds of the region
if window_end > end_position:
break

# Fetch the sequence
seq = genome.fetch(chr_name, window_start, window_end).upper()

# One-hot encode the sequence (you would need to ensure this function is available)
seq_onehot = one_hot_encode_sequence(seq)

all_sequences.append(seq_onehot)
all_coordinates.append((chr_name, int(window_start), int(window_end)))

# Stack sequences for batch processing
all_sequences = np.squeeze(np.stack(all_sequences),axis=1)

# Perform batched predictions
predictions = self.model.predict(all_sequences, verbose=0)

# Map predictions to the score array
for i, (pos, prediction) in enumerate(zip(range(start_position, end_position, step_size), predictions)):
window_start = pos
central_start = pos + (window_size - central_size) // 2
central_end = central_start + central_size

scores[central_start - start_position:central_end - start_position] += prediction[idx]
#if strand == '+':
# scores[central_start - start_position:central_end - start_position] += prediction[idx]
#else:
# scores[total_length - (central_end - start_position):total_length - (central_start - start_position)] += prediction[idx]

# Normalize the scores based on the number of times each position is included in the central window
return scores / ratio, np.array(all_coordinates), start_position, end_position, tss_position

def calculate_contribution_scores(
self,
class_names: list[str],
Expand Down
43 changes: 42 additions & 1 deletion src/crested/tl/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import numpy as np
import pyBigWig


def get_hot_encoding_table(
Expand Down Expand Up @@ -190,4 +191,44 @@ def get_value_from_dataframe(df: pd.DataFrame, row_name: str, column_name: str):
return f"Row index '{row_index}' is out of bounds for DataFrame with {len(df)} rows."
except Exception as e:
# Handle any other unexpected exceptions
return f"An error occurred: {str(e)}"
return f"An error occurred: {str(e)}"


def extract_bigwig_values_per_bp(bigwig_file, coordinates):
"""
Extract per-base pair values from a bigWig file for the given genomic coordinates.
Parameters:
bigwig_file (str): Path to the bigWig file.
coordinates (np.array): An array of tuples, each containing the chromosome name and the start and end positions of the sequence.
Returns:
bw_values (np.array): A numpy array of values from the bigWig file for each base pair in the specified range.
all_midpoints (list): A list of all base pair positions covered in the specified coordinates.
"""

# Calculate the full range of coordinates
min_coord = min([int(start) for _, start, _ in coordinates])
max_coord = max([int(end) for _, _, end in coordinates])

# Initialize the list to store values
bw_values = []

# Open the bigWig file
bw = pyBigWig.open(bigwig_file)

# Iterate over each chromosome (all coordinates should be for the same chromosome)
chrom = coordinates[0][0] # Assuming all coordinates are for the same chromosome

# Extract per-base values
bw_values = bw.values(chrom, min_coord, max_coord)

# Replace NaN with 0
bw_values = np.nan_to_num(bw_values, nan=0)

# Generate the list of all base pair positions
all_midpoints = list(range(min_coord, max_coord))

bw.close()

return bw_values, all_midpoints

0 comments on commit 982e25b

Please sign in to comment.