Skip to content

Commit

Permalink
small edits
Browse files Browse the repository at this point in the history
  • Loading branch information
Gibbsdavidl committed Nov 24, 2021
1 parent 9fe673f commit a861d04
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gene Set Scoring on the Nearest Neighbor Graph (gssnng) for Single Cell RNA-seq

Works with AnnData objects stored as h5ad files. Takes values from adata.X.

Scoring functions:
Scoring functions (these work with ranked or unranked data; "your mileage may vary"):
```
singscore: mean(ranks) / n, where n is length of gene set
Expand Down
43 changes: 27 additions & 16 deletions gssnng/score_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def average_score(su):
return((cnts_mean, std_su))


def mean_z(exprdat, su):
def mean_z(allexprvals, genesetvals):
"""
Average Z score
Expand All @@ -61,9 +61,9 @@ def mean_z(exprdat, su):
:param score_up: is the rank up or down? True or False
"""
# normalise the score for the number of genes in the signature
vals_mean = np.mean(exprdat)
vals_std = np.std(exprdat)
centered = [ (np.abs(x - vals_mean) / vals_std) for x in su ]
vals_mean = np.mean(allexprvals)
vals_std = np.std(allexprvals)
centered = [ (np.abs(x - vals_mean) / vals_std) for x in genesetvals ]
score = np.mean(centered)
return((score, vals_std))

Expand Down Expand Up @@ -132,18 +132,28 @@ def singscore(x, su, sig_len, norm_method):
return((norm_up, mad_up))


def expr_format(x, exprdat, geneset_genes):
def expr_format_2(x, exprcol, geneset_genes):
    """
    Pull the gene-set values out of one column of the expression frame.

    Alternative to ``expr_format`` that selects all matching rows in a
    single ``.loc`` call instead of looping gene by gene.

    NOTE (original author): this variant made things run twice as long.

    :param x: the gene expr data frame, indexed by gene
    :param exprcol: label of the column in ``x`` to take values from
    :param geneset_genes: iterable of gene names making up the signature
    :return: tuple ``(values, sig_len_up)`` -- the ``x[exprcol]`` values for
        the genes found in ``x.index`` and the number of distinct genes found
    """
    xset = set(x.index)
    # Build an ordered, de-duplicated list of labels rather than a set:
    # pandas no longer accepts a set as a .loc indexer, and a set would
    # also make the row order of the result non-deterministic.
    gene_overlap = list(dict.fromkeys(g for g in geneset_genes if g in xset))
    xsub = x.loc[gene_overlap]
    sig_len_up = len(gene_overlap)
    return (xsub[exprcol], sig_len_up)


def expr_format(x, exprcol, geneset_genes):
    """
    Collect the gene-set values from one column of the expression frame.

    Genes in ``geneset_genes`` that are not present in ``x.index`` are
    skipped and do not count toward the returned signature length.

    :param x: the gene expr data frame, indexed by gene
    :param exprcol: label of the column in ``x`` to take values from
    :param geneset_genes: iterable of gene names making up the signature
    :return: tuple ``(su, sig_len_up)`` -- list of values for the genes
        found (in ``geneset_genes`` order) and how many were found
    :raises KeyError: if ``exprcol`` is not a column of ``x``
    """
    # Fetch the column and the index once instead of on every iteration.
    column = x[exprcol]
    known_genes = x.index
    su = [column[gene] for gene in geneset_genes if gene in known_genes]
    # Every gene that was found contributed exactly one entry to `su`.
    sig_len_up = len(su)
    return (su, sig_len_up)


def method_selector(gs, x, exprdat, geneset_genes, method, method_params):

def method_selector(gs, x, exprcol, geneset_genes, method, method_params):
"""
:param gs: the gene set
:param x: the gene expr data frame
Expand All @@ -156,7 +166,8 @@ def method_selector(gs, x, exprdat, geneset_genes, method, method_params):
:return: dictionary of results
"""

(su, sig_len) = expr_format(x, exprdat, geneset_genes)
(su, sig_len) = expr_format(x, exprcol, geneset_genes)
exprdat = x[exprcol]

if method == 'singscore':
res0 = singscore(exprdat, su, sig_len, method_params['normalization'])
Expand Down Expand Up @@ -204,30 +215,30 @@ def scorefun(gs,
"""

if (gs.mode == 'UP') and (ranked == False):
res0 = method_selector(gs, x, x.counts, gs.genes_up, method, method_params)
res0 = method_selector(gs, x, 'counts', gs.genes_up, method, method_params)
res1 = dict(barcode = barcode, name=gs.name, mode=gs.mode, score=res0[0], var=res0[1])

elif (gs.mode == 'DN') and (ranked == False):
res0 = method_selector(gs, x, x.counts, gs.genes_dn, method, method_params)
res0 = method_selector(gs, x, 'counts', gs.genes_dn, method, method_params)
res1 = dict(barcode = barcode, name=gs.name, mode=gs.mode, score=res0[0], var=res0[1])

elif (gs.mode == 'BOTH') and (ranked == False):
res0_up = method_selector(gs, x, x.counts, gs.genes_up , method, method_params)
res0_dn = method_selector(gs, x, x.counts, gs.genes_dn, method, method_params)
res0_up = method_selector(gs, x, 'counts', gs.genes_up, method, method_params)
res0_dn = method_selector(gs, x, 'counts', gs.genes_dn, method, method_params)
res1 = dict(barcode = barcode, name=gs.name, mode=gs.mode,
score=(res0_up[0]+res0_dn[0]), var=(res0_up[1]+res0_dn[1]))

elif (gs.mode == 'UP') and (ranked == True):
res0 = method_selector(gs, x, x.uprank, gs.genes_up, method, method_params)
res0 = method_selector(gs, x, 'uprank', gs.genes_up, method, method_params)
res1 = dict(barcode = barcode, name=gs.name, mode=gs.mode, score=res0[0], var=res0[1])

elif (gs.mode == 'DN') and (ranked == True):
res0 = method_selector(gs, x, x.dnrank, gs.genes_dn, method, method_params)
res0 = method_selector(gs, x, 'dnrank', gs.genes_dn, method, method_params)
res1 = dict(barcode = barcode, name=gs.name, mode=gs.mode, score=res0[0], var=res0[1])

elif (gs.mode == 'BOTH') and (ranked == True):
res0_up = method_selector(gs, x, x.uprank, gs.genes_up , method, method_params)
res0_dn = method_selector(gs, x, x.dnrank, gs.genes_dn, method, method_params)
res0_up = method_selector(gs, x, 'uprank', gs.genes_up , method, method_params)
res0_dn = method_selector(gs, x, 'dnrank', gs.genes_dn, method, method_params)
res1 = dict(barcode = barcode, name=gs.name, mode=gs.mode,
score=(res0_up[0]+res0_dn[0]), var=(res0_up[1]+res0_dn[1]))

Expand Down
9 changes: 6 additions & 3 deletions gssnng/test/test_score_all_cells_all_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import scanpy as sc
from gssnng.score_cells import with_gene_sets

import time

def test_score_all_sets_fun(adata, genesets):
res0 = with_gene_sets(adata=adata, gene_set_file=genesets, score_method='mean_z', method_params=dict(),
Expand All @@ -17,12 +17,15 @@ def test_score_all_sets():
print("computing knn...")
sc.pp.neighbors(q2, n_neighbors=32)
print('scoring...')
t0 = time.time()
print('start time: ' + str(t0))
score_list = test_score_all_sets_fun(q2, gs)
print('******DONE*******')
t1 = time.time()
print('end time: ' + str(t1))
print('TOTAL TIME: ' + str(t1-t0))
print(q2.obs.head())
print(q2.obs.columns)
#q.write_h5ad('data/pbmc3k_lm22_scores.h5ad')

test_score_all_sets()
print('test score_all_sets done')

0 comments on commit a861d04

Please sign in to comment.