From 1dcd441d4e931909007f06a60c9e285c994699b0 Mon Sep 17 00:00:00 2001 From: K D Ahlquist <6993450+kdahlo@users.noreply.github.com> Date: Tue, 4 Apr 2023 11:39:27 -0400 Subject: [PATCH] Added support and testing for MNPs --- spliceai/utils.py | 13 +++++++++---- tests/test_delta_score.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/spliceai/utils.py b/spliceai/utils.py index 991afd1..3237fa7 100644 --- a/spliceai/utils.py +++ b/spliceai/utils.py @@ -135,10 +135,6 @@ def get_delta_scores(record, ann, dist_var, mask): if '<' in record.alts[j] or '>' in record.alts[j]: continue - if len(record.ref) > 1 and len(record.alts[j]) > 1: - delta_scores.append("{}|{}|.|.|.|.|.|.|.|.".format(record.alts[j], genes[i])) - continue - dist_ann = ann.get_pos_data(idxs[i], record.pos) pad_size = [max(wid//2+dist_ann[0], 0), max(wid//2-dist_ann[1], 0)] ref_len = len(record.ref) @@ -174,6 +170,15 @@ def get_delta_scores(record, ann, dist_var, mask): np.max(y_alt[:, cov//2:cov//2+alt_len], axis=1)[:, None, :], y_alt[:, cov//2+alt_len:]], axis=1) + #MNP handling + elif ref_len > 1 and alt_len > 1: + zblock = np.zeros((1,ref_len-1,3)) + y_alt = np.concatenate([ + y_alt[:, :cov//2], + np.max(y_alt[:, cov//2:cov//2+alt_len], axis=1)[:, None, :], + zblock, + y_alt[:, cov//2+alt_len:]], + axis=1) y = np.concatenate([y_ref, y_alt]) diff --git a/tests/test_delta_score.py b/tests/test_delta_score.py index 4630095..5560aee 100755 --- a/tests/test_delta_score.py +++ b/tests/test_delta_score.py @@ -44,3 +44,17 @@ def test_get_delta_score_donor(self): self.assertEqual(scores, ['T|TUBB8|0.01|0.18|0.15|0.62|-2|110|-190|0']) scores = get_delta_scores(record, self.ann_without_prefix, 500, 0) self.assertEqual(scores, ['T|TUBB8|0.01|0.18|0.15|0.62|-2|110|-190|0']) + + def test_get_delta_score_mnp(self): + + record = Record('10', 94077, 'ACT', ['CCT']) + scores = get_delta_scores(record, self.ann, 50, 0) + self.assertEqual(scores, ['CCT|TUBB8|0.07|0.27|0.00|0.01|0|-23|19|-22']) + scores = get_delta_scores(record, self.ann_without_prefix, 50, 0) + self.assertEqual(scores, ['CCT|TUBB8|0.07|0.27|0.00|0.01|0|-23|19|-22']) + + record = Record('10', 94555, 'CGA', ['TGA']) + scores = get_delta_scores(record, self.ann, 50, 0) + self.assertEqual(scores, ['TGA|TUBB8|0.01|0.00|0.11|0.62|-2|-6|-23|0']) + scores = get_delta_scores(record, self.ann_without_prefix, 50, 0) + self.assertEqual(scores, ['TGA|TUBB8|0.01|0.00|0.11|0.62|-2|-6|-23|0'])