Commit 176e4ad: Merge branch 'dev'
kmoad committed Apr 24, 2024 (2 parents: 79069bf + e071cae)
Showing 13 changed files with 292 additions and 598 deletions.
6 changes: 6 additions & 0 deletions cravat/base_converter.py
@@ -1,3 +1,6 @@
from cravat.config_loader import ConfigLoader


class BaseConverter(object):
IGNORE = "converter_ignore"

@@ -6,6 +9,9 @@ def __init__(self):
self.output_dir = None
self.run_name = None
self.input_assembly = None
self.module_name = self.__class__.__module__
config_loader = ConfigLoader()
self.conf = config_loader.get_module_conf(self.module_name)

def check_format(self, *args, **kwargs):
err_msg = (
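
With this commit, every converter loads its module configuration in the base
class and exposes it as self.conf. A minimal sketch of a subclass using it
(the config key is hypothetical; get_module_conf is assumed to return the
module's YAML conf as a dict):

    from cravat.base_converter import BaseConverter

    class CravatConverter(BaseConverter):
        def __init__(self):
            super().__init__()  # sets self.module_name and loads self.conf
            self.format_name = 'example'
            # hypothetical key; use whatever the module's YAML defines
            self.max_alts = self.conf.get('max_alternate_alleles', 10)
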
1 change: 1 addition & 0 deletions cravat/cravat_convert.py
@@ -260,6 +260,7 @@ def _initialize_converters(self):
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
converter = module.CravatConverter()
converter.input_assembly = self.input_assembly
if converter.format_name not in self.converters:
self.converters[converter.format_name] = converter
else:
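
The master converter now assigns input_assembly to each converter instance it
creates, so format-specific code can branch on the source assembly. A hedged
sketch, assuming the usual setup(f) hook of the converter interface; the
liftover flag is illustrative:

    class CravatConverter(BaseConverter):
        format_name = 'example'

        def setup(self, f):
            # self.input_assembly (e.g. 'hg19') is populated by
            # _initialize_converters before conversion starts.
            self.needs_liftover = self.input_assembly not in (None, 'hg38')
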
139 changes: 118 additions & 21 deletions cravat/cravat_filter.py
@@ -564,10 +564,7 @@ async def getcount(self, level="variant", conn=None, cursor=None):
ftable = level
else:
ftable = level + "_filtered"
q = "select count(*) from " + ftable
await cursor.execute(q)
for row in await cursor.fetchone():
n = row
n = await self.exec_db(self.get_filtered_count, level=level)
if self.stdout == True:
print("#" + level)
print(str(n))
@@ -645,8 +642,110 @@ async def getiterator(self, level="variant", conn=None, cursor=None):
it = await cursor.fetchall()
return it

@staticmethod
def reaggregate_column(base_alias, meta):
column = meta['name']
function = meta.get('filter_reagg_function', None)
reagg_args = meta.get('filter_reagg_function_args', [])
reagg_source = meta.get('filter_reagg_source_column', None)

if not function:
return "{}.{}".format(base_alias, column)

reagg_template = "{}({}{}) OVER (PARTITION BY {}.base__uid ORDER BY sample.base__sample_id ROWS BETWEEN UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) {}"
quoted_args = ["'{}'".format(x) for x in reagg_args]
formatted_args = ",{}".format(",".join(quoted_args)) if reagg_args else ""
return reagg_template.format(function, reagg_source, formatted_args, base_alias, column)
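
A hedged worked example of the SQL this helper emits, assuming the enclosing
class is CravatFilter; only the metadata key names come from this commit, the
values are hypothetical:

    meta = {
        'name': 'af_max',                                 # output alias
        'filter_reagg_function': 'MAX',                   # SQL aggregate to apply
        'filter_reagg_function_args': [],                 # extra args, quoted when present
        'filter_reagg_source_column': 'sample.base__af',  # column to aggregate
    }
    print(CravatFilter.reaggregate_column('v', meta))
    # MAX(sample.base__af) OVER (PARTITION BY v.base__uid
    #   ORDER BY sample.base__sample_id ROWS BETWEEN UNBOUNDED PRECEDING
    #   and UNBOUNDED FOLLOWING) af_max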

@staticmethod
async def level_column_definitions(cursor, level):
await cursor.execute("select col_name, col_def from {}_header".format(level))
return {k: json.loads(v) for k, v in await cursor.fetchall()}

async def make_sample_filter_group(self, cursor, sample_filter):
sample_columns = await self.level_column_definitions(cursor, 'sample')
prefixes = {k: 'sample' for k in sample_columns.keys()}
filter_group = FilterGroup(sample_filter)
filter_group.add_prefixes(prefixes)
return filter_group

async def get_filtered_iterator(self, level="variant", conn=None, cursor=None):
bypassfilter = not(self.filter or self.filtersql or self.includesample or self.excludesample)
sql = await self.build_base_sql(cursor, level)

if level == 'variant' and self.filter and 'samplefilter' in self.filter and len(self.filter['samplefilter']['rules']) > 0:
sample_filter = self.filter['samplefilter']
variant_columns = await self.level_column_definitions(cursor, 'variant')

reaggregated_columns = [self.reaggregate_column('v', meta) for col, meta in variant_columns.items()]
sample_filters = self.build_sample_exclusions()
filter_group = await self.make_sample_filter_group(cursor, sample_filter)

sql = """
with base_variant as ({}),
scoped_sample as (
select *
from sample
where 1=1
{}
)
select distinct {}
from base_variant v
join scoped_sample sample on sample.base__uid = v.base__uid
where {}
""".format(sql, sample_filters, ",".join(reaggregated_columns), filter_group.get_sql())

await cursor.execute(sql)
cols = [v[0] for v in cursor.description]
rows = await cursor.fetchall()

return cols, rows
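
A hedged usage sketch for this entry point, assuming an initialized filter
object cf and the exec_db wrapper used elsewhere in this class:

    cols, rows = await cf.exec_db(cf.get_filtered_iterator, level='variant')
    for row in rows:
        print(dict(zip(cols, row)))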

async def get_filtered_count(self, level="variant", conn=None, cursor=None):

if level == 'variant' and self.filter and 'samplefilter' in self.filter and len(self.filter['samplefilter']['rules']) > 0:
sql = await self.build_base_sql(cursor, level)
sample_filter = self.filter['samplefilter']
variant_columns = await self.level_column_definitions(cursor, 'variant')

reaggregated_columns = [self.reaggregate_column('v', meta) for col, meta in variant_columns.items()]
sample_filters = self.build_sample_exclusions()
filter_group = await self.make_sample_filter_group(cursor, sample_filter)

sql = """
with base_variant as ({}),
scoped_sample as (
select *
from sample
where 1=1
{}
)
select count(distinct v.base__uid)
from base_variant v
join scoped_sample sample on sample.base__uid = v.base__uid
where {}
""".format(sql, sample_filters, filter_group.get_sql())
else:
sql = await self.build_base_sql(cursor, level, count=True)
await cursor.execute(sql)
rows = await cursor.fetchall()

return rows[0][0]

def build_sample_exclusions(self):
# this is needed because joining back to the sample table causes
# re-inclusion of sample data that was excluded at the variant level.
sample_filters = ""
req, rej = self.required_and_rejected_samples()
if req:
sample_filters += "and base__sample_id in ({})".format(
", ".join(["'{}'".format(sid) for sid in req]))
if rej:
sample_filters += "and base__sample_id not in ({})".format(
", ".join(["'{}'".format(sid) for sid in rej]))
return sample_filters

async def build_base_sql(self, cursor, level, count=False):
bypassfilter = not (self.filter or self.filtersql or self.includesample or self.excludesample)
if level == "variant":
kcol = "base__uid"
if bypassfilter:
@@ -682,29 +781,20 @@ async def get_filtered_iterator(self, level="variant", conn=None, cursor=None):
", ".join(colnames)
)
else:
sql = "select v.* from " + table + " as v"
if not count:
sql = "select v.* from " + table + " as v"
else:
sql = "select count(v.base__uid) from " + table + " as v"
if bypassfilter == False:
sql += " inner join " + ftable + " as f on v." + kcol + "=f." + kcol
-        await cursor.execute(sql)
-        cols = [v[0] for v in cursor.description]
-        rows = await cursor.fetchall()
-        return cols, rows
+        return sql

async def make_filtered_sample_table(self, conn=None, cursor=None):
q = "drop table if exists fsample"
await cursor.execute(q)
await conn.commit()
-        req = []
-        rej = []
-        if "sample" in self.filter:
-            if "require" in self.filter["sample"]:
-                req = self.filter["sample"]["require"]
-            if "reject" in self.filter["sample"]:
-                rej = self.filter["sample"]["reject"]
-        if self.includesample is not None:
-            req = self.includesample
-        if self.excludesample is not None:
-            rej = self.excludesample
+        req, rej = self.required_and_rejected_samples()
if len(req) > 0 or len(rej) > 0:
q = "create table fsample as select distinct base__uid from sample"
if req:
@@ -721,6 +811,13 @@
else:
return False

def required_and_rejected_samples(self):
sample = self.filter.get("sample", {})
req = sample.get("require", self.includesample or [])
rej = sample.get("reject", self.excludesample or [])

return req, rej
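
A hedged example of the consolidated lookup above (sample IDs hypothetical).
The CLI fallbacks includesample/excludesample apply only when the filter dict
lacks the corresponding key:

    cf.filter = {'sample': {'require': ['tumor_1'], 'reject': ['normal_1']}}
    req, rej = cf.required_and_rejected_samples()
    # req == ['tumor_1'], rej == ['normal_1']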

async def make_filter_where(self, conn=None, cursor=None):
q = ""
if len(self.filter) == 0:
3 changes: 3 additions & 0 deletions cravat/cravat_report.py
@@ -613,6 +613,9 @@ async def get_variant_colinfo(self):
if await self.exec_db(self.table_exists, level):
await self.exec_db(self.make_col_info, level)
level = "gene"
if await self.exec_db(self.table_exists, level):
await self.exec_db(self.make_col_info, level)
level = "sample"
if await self.exec_db(self.table_exists, level):
await self.exec_db(self.make_col_info, level)
return self.colinfo
1 change: 0 additions & 1 deletion cravat/cravat_web.py
@@ -7,7 +7,6 @@
import json
import sys
import argparse
-import imp
import oyaml as yaml
import re
from cravat import admin_util as au
3 changes: 3 additions & 0 deletions cravat/inout.py
@@ -543,6 +543,9 @@ def _load_dict(self, d):
self.filterable = bool(d.get("filterable", True))
self.link_format = d.get("link_format")
self.genesummary = d.get("genesummary", False)
self.filter_reagg_function = d.get("filter_reagg_function", None)
self.filter_reagg_function_args = d.get("filter_reagg_function_args", [])
self.filter_reagg_source_column = d.get("filter_reagg_source_column", None)
self.table = d.get("table", False)

def from_row(self, row, order=None):
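
These three keys let a module's column definition opt into the sample-level
re-aggregation consumed by reaggregate_column in cravat_filter.py. A hedged
sketch of such a definition as the dict handed to _load_dict (all values
hypothetical):

    d = {
        'name': 'af_max',
        'title': 'Max Allele Frequency',
        'type': 'float',
        'filter_reagg_function': 'MAX',
        'filter_reagg_function_args': [],
        'filter_reagg_source_column': 'sample.base__af',
    }
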
13 changes: 10 additions & 3 deletions cravat/oc.py
@@ -6,7 +6,7 @@
from cravat.cravat_report import parser as report_parser
from cravat.vcfanno import vcfanno
import sys
-from pathlib import Path
+import pathlib

root_p = argparse.ArgumentParser(
description="Open-CRAVAT genomic variant interpreter. https://github.com/KarchinLab/open-cravat"
@@ -240,9 +240,16 @@
type = int,
help = 'Number of CPU threads to use')
vcfanno_p.add_argument('--temp-dir',
-    type = Path,
-    default = Path('temp-vcfanno'),
+    type = pathlib.Path,
+    default = pathlib.Path('temp-vcfanno'),
help = 'Temporary directory for working files')
vcfanno_p.add_argument('-o','--output-path',
type = pathlib.Path,
help = 'Output vcf path (gzipped). Defaults to input_path.oc.vcf.gz')
vcfanno_p.add_argument('--chunk-size',
type = int,
default = 10**4,
help = 'Number of lines to annotate in each thread before syncing to disk. Affects performance.')
vcfanno_p.set_defaults(func=vcfanno)

def main():
12 changes: 8 additions & 4 deletions cravat/vcfanno.py
@@ -397,7 +397,10 @@ def process(self):

def vcfanno(args):
input_path = pathlib.Path(args.input_path)
-    output_path = pathlib.Path(str(input_path)+'.oc.vcf.gz')
+    if args.output_path is not None:
+        output_path = args.output_path
+    else:
+        output_path = pathlib.Path(str(input_path)+'.oc.vcf.gz')
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -413,7 +416,8 @@ def vcfanno(args):
output_path = str(output_path),
temp_dir = args.temp_dir,
processors = args.threads if args.threads else mp.cpu_count(),
-        chunk_size=10**4,
-        chunk_log_frequency=50,
-        annotators=args.annotators)
+        chunk_size= args.chunk_size,
+        chunk_log_frequency = 50,
+        annotators = args.annotators,
+        )
anno.process()
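
A hedged sketch of driving vcfanno() directly with the arguments oc.py now
wires up; the paths and the annotators value are hypothetical:

    import argparse, pathlib

    args = argparse.Namespace(
        input_path='sample.vcf.gz',
        output_path=None,            # None falls back to 'sample.vcf.gz.oc.vcf.gz'
        temp_dir=pathlib.Path('temp-vcfanno'),
        threads=None,                # None falls back to mp.cpu_count()
        chunk_size=10**4,            # lines per worker between disk syncs
        annotators=None,             # hypothetical: run the default annotator set
    )
    vcfanno(args)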