Skip to content

Commit

Permalink
Improvements to qc script
Browse files Browse the repository at this point in the history
  • Loading branch information
pcm32 committed Oct 14, 2024
1 parent 7796e54 commit 873197a
Showing 1 changed file with 231 additions and 21 deletions.
252 changes: 231 additions & 21 deletions tools/tertiary-analysis/scanpy/scripts/sc_qc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns

# import seaborn as sns


def main():
Expand All @@ -12,7 +13,10 @@ def main():
)
parser.add_argument("adata_file", type=str, help="Path to AnnData object file")
parser.add_argument(
"sample_field", type=str, help="Field in the obs for the sample identifier"
"--sample_field",
type=str,
default="Sample_ID",
help="Field in the obs for the sample identifier"
)
parser.add_argument(
"--output_format",
Expand All @@ -28,6 +32,34 @@ def main():
metavar=("width", "height"),
help="Size of the plots",
)
# add an argument for general plot title font size
parser.add_argument(
"--title_font_size",
type=int,
default=12,
help="General plot title font size",
)
# add an argument for general plot label font size
parser.add_argument(
"--label_font_size",
type=int,
default=8,
help="General plot label font size",
)
# add an argument for general plot legend font size
parser.add_argument(
"--legend_font_size",
type=int,
default=10,
help="General plot legend font size",
)
# add an argument for the gene symbols field
parser.add_argument(
"--gene_symbols_field",
type=str,
default="gene_symbols",
help="Field in the var for the gene symbols",
)
parser.add_argument(
"--percent_mito_field",
type=str,
Expand Down Expand Up @@ -58,6 +90,13 @@ def main():
default="doublet_score",
help="Field in the obs for the doublet score",
)
# add an argument for an embedding to plot the cells
parser.add_argument(
"--embedding",
type=str,
default=None,
help="Embedding to plot the cells",
)
args = parser.parse_args()

# Load AnnData object
Expand All @@ -67,8 +106,14 @@ def main():
if args.plot_size:
sc.settings.figsize = tuple(args.plot_size)

# Set output format
sc.settings.set_figure_params(format=args.output_format)
# set scanpy general plot font size and output format
sc.settings.set_figure_params(scanpy=True,
fontsize=args.label_font_size,
format=args.output_format)
# disable FutureWarning
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

run_quality_control = False
if "n_genes_by_counts" not in adata.obs.columns:
Expand All @@ -77,13 +122,30 @@ def main():
run_quality_control = True

qc_vars = []
fields = [
"n_genes_by_counts",
"total_counts",
args.percent_mito_field,
args.percent_ribo_field
]
# calculate mitochondrial genes if not provided
if args.percent_mito_field not in adata.obs.columns:
qc_vars.append(args.mito_field)
# calculate ribo metrics if not provided
if args.ribo_field not in adata.var.columns:
# create a new column with the name args.ribo_field where genes that
# have in the gene symbols field the pattern ^RP[SL] are
# marked as true
print(f"Creating {args.ribo_field} column")
adata.var[args.ribo_field] = adata.var[args.gene_symbols_field].str.contains(
"^RP[SL]"
)
print(f"Number of ribosomal genes: {adata.var[args.ribo_field].sum()}")
if args.percent_ribo_field not in adata.obs.columns:
qc_vars.append(args.ribo_field)

print(f"Calculating QC metrics for {len(qc_vars)} variables")

if len(qc_vars) > 0 or run_quality_control:
sc.pp.calculate_qc_metrics(
adata,
Expand All @@ -95,16 +157,54 @@ def main():
adata.obs["n_genes"] = adata.obs["n_genes_by_counts"]
adata.var["n_counts"] = adata.var["total_counts"]
adata.var["n_cells"] = adata.var["n_cells_by_counts"]

# Define thresholds
high_umi_threshold = adata.obs['n_counts'].quantile(0.95) # Top 5% most UMI counts
low_umi_threshold = adata.obs['n_counts'].quantile(0.05) # Bottom 5% least UMI counts
high_mito_threshold = adata.obs[args.percent_mito_field].quantile(0.90) # Top 10% pct mitochondrial genes

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# Polynomial regression to account for curvature in the n_counts vs. n_genes relationship
poly = PolynomialFeatures(degree=2)
X_poly = poly.fit_transform(adata.obs[['n_counts']])
model = LinearRegression()
model.fit(X_poly, adata.obs['n_genes'])
predicted_counts = model.predict(X_poly)

# Calculate residuals
residuals = adata.obs['n_genes'] - predicted_counts
outlier_threshold = residuals.abs().quantile(0.95) # Top 5% residuals as outliers

# Initialize diagnosis column
adata.obs['auto_diagnosis'] = 'Healthy'

# Identify outliers
outliers = residuals.abs() > outlier_threshold
adata.obs.loc[outliers, 'auto_diagnosis'] = 'Outlier'


# Identify stressed/dying/apoptotic cells
stressed_cells = (adata.obs['n_counts'] > high_umi_threshold) & (adata.obs[args.percent_mito_field] > high_mito_threshold)
adata.obs.loc[stressed_cells, 'auto_diagnosis'] = 'Stressed/Dying/Apoptotic'

# Identify poor-quality cells
poor_quality_cells = (adata.obs['n_counts'] < low_umi_threshold) & (adata.obs[args.percent_mito_field] > high_mito_threshold)
adata.obs.loc[poor_quality_cells, 'auto_diagnosis'] = 'Poor-Quality'

# Print diagnosis summary
print(adata.obs['auto_diagnosis'].value_counts())
# make a barplot of the auto_diagnosis, omitting the healthy cells from the plot
# but writing the number of healthy cells in the title. Plot per sample
healthy_cells = adata.obs['auto_diagnosis'] == 'Healthy'
healthy_count = healthy_cells.sum()

# General quality for whole dataset
plt.figure()
ax = sc.pl.violin(
adata,
[
"n_genes_by_counts",
"total_counts",
args.percent_mito_field,
args.percent_ribo_field,
],
fields,
jitter=False,
multi_panel=True,
show=False,
Expand All @@ -115,11 +215,18 @@ def main():

# Generate quality control plots
generate_violin_plots(
adata, args.sample_field, args.percent_mito_field, format=args.output_format
adata, args.sample_field, args.percent_mito_field,
args.percent_ribo_field, format=args.output_format
)
generate_scatter_plot(
adata,
args.sample_field,
percent_mito_field=args.percent_mito_field,
)
generate_scatter_plot(
adata,
args.sample_field,
y='log1p_n_genes_by_counts',
percent_mito_field=args.percent_mito_field,
)
if args.doublet_score_field in adata.obs.columns:
Expand All @@ -130,8 +237,24 @@ def main():
format=args.output_format,
)
else:
print("Doublet score field provided not in adata.obs.columns, skipping plot.")
print(
"Doublet score field provided not in adata.obs.columns, " + "skipping plot."
)
generate_complexity_plot(adata, args.sample_field, format=args.output_format)

if args.embedding:
generate_embedding_plot(
adata,
fields=fields + [args.sample_field, 'auto_diagnosis'],
embedding=args.embedding,
format=args.output_format,
)

generate_barplot(adata[~healthy_cells],
groups_field=args.sample_field,
props_field='auto_diagnosis',
figure_path='diagnosis_barplot.pdf',
topic_for_title=f"(Total Healthy/Unhealthy cells: {healthy_count}/{adata.n_obs - healthy_count})")
# generate_scatter_by_sample(
# adata,
# sample_field=args.sample_field,
Expand All @@ -140,10 +263,69 @@ def main():
# )


def generate_barplot(
adata, groups_field, props_field, figure_path=None, topic_for_title=None
):
"""
Generate a proportional bar plot from an AnnData object.
Parameters:
adata (AnnData): The input AnnData object containing the data to plot.
groups_field (str): The column in adata.obs to group the data by.
props_field (str): The column in adata.obs to plot as proportions.
figure_path (str, optional): The path to save the generated figure. If not provided, the figure is not saved.
topic_for_title (str, optional): The topic to be used in the figure title, goes after {props_field} proportion of {topic_for_title} per {groups_field}.
Returns:
matplotlib.figure.Figure: The generated bar plot.
"""
props_plot_data = adata.obs[[groups_field, props_field]]
# props_plot_data[groups_field] = props_plot_data[groups_field].cat.reorder_categories(['control', '2 days', '7 days', '10 days', '14 days'])
# make a 100% stacked bar plot of props_plot_data, plotting phase counts grouped by cell_line_persister

grouped = props_plot_data.groupby([groups_field, props_field]).size().unstack()
# proportions = grouped.div(grouped.sum(axis=1), axis=0)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
plt.gca().set_prop_cycle(color=colors[2:5])
grouped.plot(kind="bar", stacked=False, figsize=(8, 6))
if topic_for_title is not None:
plt.title(f"{props_field} cells {topic_for_title}\nper {groups_field}")
else:
plt.title(f"{props_field} cells\nper {groups_field}")
plt.xlabel(groups_field)
# plt.xticks(rotation=45)
plt.ylabel("Number of cells")

# save plot to PDF file
if figure_path is not None:
plt.savefig(figure_path, bbox_inches="tight")
return plt.figure()


def generate_embedding_plot(
adata,
fields,
embedding,
format="pdf"
):
# Embedding plot
plt.figure()
sc.pl.embedding(
adata,
basis=embedding,
color=fields,
show=False,
ncols=1
)
plt.savefig(f"embedding_plots.{format}", bbox_inches="tight")
plt.close()


def generate_violin_plots(
adata,
sample_field,
percent_mito_field="percent_mito",
percent_ribo_field="percent_ribo",
format="pdf",
gene_symbols_field="gene_symbols",
):
Expand Down Expand Up @@ -171,7 +353,7 @@ def generate_violin_plots(
groupby=sample_field,
ax=ax,
title="Number of Genes per Cell (Separated by Sample)",
show=False
show=False,
# show=True,
# save="_n_genes_per_cell",
)
Expand All @@ -187,20 +369,35 @@ def generate_violin_plots(
percent_mito_field,
groupby=sample_field,
ax=ax,
title="Percentage of Mitochondrial Genes per Cell (Separated by Sample)",
title="Percentage of Mitochondrial " + "Genes per Cell (Separated by Sample)",
show=False,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
plt.savefig(f"percent_mito_per_cell.{format}", bbox_inches="tight")
plt.close()

# Percentage of ribosomal genes per cell
plt.figure()
ax = plt.gca()
sc.pl.violin(
adata,
percent_ribo_field,
groupby=sample_field,
ax=ax,
title="Percentage of Ribosomal " + "Genes per Cell (Separated by Sample)",
show=False,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
plt.savefig(f"percent_ribo_per_cell.{format}", bbox_inches="tight")
plt.close()

# highest expressed genes per cell
plt.figure()
ax = sc.pl.highest_expr_genes(
adata, n_top=30, gene_symbols=gene_symbols_field, show=False
)
# set title of ax
ax.set_title(f"Highest expressed genes per cell (Separated by Sample)")
ax.set_title("Highest expressed genes per cell\nby Sample")
plt.savefig(f"highest_expr_genes.{format}", bbox_inches="tight")
plt.close()

Expand All @@ -223,17 +420,18 @@ def generate_violin_plots(
def generate_scatter_plot(
adata,
sample_field,
y="n_genes",
percent_mito_field="percent_mito",
):
# Scatter plot of UMIs vs genes detected
plt.figure()
sc.pl.scatter(
adata,
x="n_counts",
y="n_genes",
y=y,
color=sample_field,
title="UMIs vs Genes Detected (Separated by Sample)",
save="_umi_vs_genes_detected",
title="UMIs vs Genes Detected (by Sample)",
save=f"_umi_vs_{y}_detected",
show=False,
)
plt.close()
Expand All @@ -243,14 +441,26 @@ def generate_scatter_plot(
sc.pl.scatter(
adata,
x="n_counts",
y="n_genes",
y=y,
color=percent_mito_field,
title="UMIs vs Genes Detected (Colored by Mitochondrial Gene Ratio)",
save="_umi_vs_genes_detected_colored_by_mito",
title="UMIs vs Genes Detected (by Mitochondrial Gene Ratio)",
save=f"_umi_vs_{y}_detected_colored_by_mito",
show=False,
)
plt.close()

plt.figure()
sc.pl.scatter(
adata,
x='n_counts',
y=y,
color='auto_diagnosis',
title="UMIs vs Genes Detected (by Mitochondrial Gene Ratio)",
save=f"_umi_vs_{y}_detected_colored_by_auto_diagnosis",
show=False
)
plt.close()


def generate_scatter_by_sample(
adata, sample_field, percent_mito_field="percent_mito", format="pdf"
Expand Down

0 comments on commit 873197a

Please sign in to comment.