diff --git a/snekmer/rules/model.smk b/snekmer/rules/model.smk index 317d4ce..3edbf40 100644 --- a/snekmer/rules/model.smk +++ b/snekmer/rules/model.smk @@ -52,8 +52,10 @@ unzipped = [ fa.rstrip(".gz") for fa, ext in product(input_files, config["input_file_exts"]) if fa.rstrip(".gz").endswith(f".{ext}") + and skm.utils.check_n_seqs(fa, config["model"]["cv"], show_warning=False) ] + # map extensions to basename (basename.ext.gz -> {basename: ext}) UZ_MAP = { skm.utils.split_file_ext(f)[0]: skm.utils.split_file_ext(f)[1] for f in zipped @@ -62,6 +64,11 @@ FA_MAP = { skm.utils.split_file_ext(f)[0]: skm.utils.split_file_ext(f)[1] for f in unzipped } +# final file map: checks that files are large enough for model building +# FA_MAP = { +# k: v for k, v in f_map.items() if skm.utils.check_n_seqs(k, config["model"]["cv"]) +# } + # get unzipped filenames UZS = [f"{f}.{ext}" for f, ext in UZ_MAP.items()] @@ -82,6 +89,11 @@ out_dir = skm.io.define_output_dir( config["alphabet"], config["k"], nested=config["nested_output"] ) +# show warnings if files excluded +onstart: + [ + skm.utils.check_n_seqs(fa, config["model"]["cv"], show_warning=True) for fa in input_files + ] # define output files to be created by snekmer rule all: diff --git a/snekmer/scripts/model_model.py b/snekmer/scripts/model_model.py index 0e15ac8..0d8886e 100644 --- a/snekmer/scripts/model_model.py +++ b/snekmer/scripts/model_model.py @@ -40,9 +40,7 @@ # ) # print(lookup) with open(snakemake.input.matrix, "rb") as f: - matrix = pickle.load(f) - -data = matrix + data = pickle.load(f) # load all input data and encode rule-wide variables # data = pd.read_csv(input.data) diff --git a/snekmer/utils.py b/snekmer/utils.py index 42e9f40..6f3a333 100644 --- a/snekmer/utils.py +++ b/snekmer/utils.py @@ -7,6 +7,7 @@ import collections.abc import datetime import re + from os.path import basename, splitext from typing import Any, List, Optional, Tuple, Union @@ -200,3 +201,53 @@ def to_feature_matrix( length_array = np.ones(len(array)) array = [np.array(a) / length for a, length in zip(array, length_array)] return np.asarray(array) + + +def count_n_seqs(filename: str) -> int: + """Count number of sequences in a file. + + Parameters + ---------- + filename : str + /path/to/sequence/file.fasta + Other text formatted sequence files (.faa, etc.) also work. + + Returns + ------- + int + Number of sequences contained within the input file. + + """ + return len([1 for line in open(filename) if line.startswith(">")]) + + +def check_n_seqs(filename: str, k: int, show_warning: bool = True) -> bool: + """Check that a file contains at least k sequences. + + Parameters + ---------- + filename : str + /path/to/sequence/file.fasta + Other text formatted sequence files (.faa, etc.) also work. + k : int + Minimum threshold. + show_warning : bool, optional + When True, if len(file) < k, a warning is displayed; + by default True. + + Returns + ------- + bool + True if len(file) < k; False otherwise. + + """ + n_seqs = len([1 for line in open(filename) if line.startswith(">")]) + if (n_seqs < k) and (show_warning): + print( + f"\nWARNING: {filename} contains an insufficient" + " number of sequences for model cross-validation" + " and will thus be excluded from Snekmer modeling." + f" ({k} folds specified in config; {n_seqs}" + " sequence(s) detected.)\n" + ) + return n_seqs >= k