Skip to content

Commit

Permalink
[feat] Implement min file size for use in models
Browse files Browse the repository at this point in the history
Changelog:
- `snekmer.utils`: Helper functions added to check file length
- File size is now checked to meet minimum threshold for K-fold cross-validation
- If min file size is not met, file is excluded from the pipeline and a warning is printed to the user to notify them of file exclusion
- Note: warning is not active during a dry run due to the behavior of the `onstart` directive
  • Loading branch information
christinehc committed Dec 8, 2022
1 parent 7eb479a commit 619f554
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
12 changes: 12 additions & 0 deletions snekmer/rules/model.smk
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ unzipped = [
fa.rstrip(".gz")
for fa, ext in product(input_files, config["input_file_exts"])
if fa.rstrip(".gz").endswith(f".{ext}")
and skm.utils.check_n_seqs(fa, config["model"]["cv"], show_warning=False)
]


# map extensions to basename (basename.ext.gz -> {basename: ext})
UZ_MAP = {
skm.utils.split_file_ext(f)[0]: skm.utils.split_file_ext(f)[1] for f in zipped
Expand All @@ -62,6 +64,11 @@ FA_MAP = {
skm.utils.split_file_ext(f)[0]: skm.utils.split_file_ext(f)[1] for f in unzipped
}

# final file map: checks that files are large enough for model building
# FA_MAP = {
# k: v for k, v in f_map.items() if skm.utils.check_n_seqs(k, config["model"]["cv"])
# }

# get unzipped filenames
UZS = [f"{f}.{ext}" for f, ext in UZ_MAP.items()]

Expand All @@ -82,6 +89,11 @@ out_dir = skm.io.define_output_dir(
config["alphabet"], config["k"], nested=config["nested_output"]
)

# show warnings if files excluded
onstart:
[
skm.utils.check_n_seqs(fa, config["model"]["cv"], show_warning=True) for fa in input_files
]

# define output files to be created by snekmer
rule all:
Expand Down
4 changes: 1 addition & 3 deletions snekmer/scripts/model_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@
# )
# print(lookup)
with open(snakemake.input.matrix, "rb") as f:
matrix = pickle.load(f)

data = matrix
data = pickle.load(f)

# load all input data and encode rule-wide variables
# data = pd.read_csv(input.data)
Expand Down
51 changes: 51 additions & 0 deletions snekmer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import collections.abc
import datetime
import re

from os.path import basename, splitext
from typing import Any, List, Optional, Tuple, Union

Expand Down Expand Up @@ -200,3 +201,53 @@ def to_feature_matrix(
length_array = np.ones(len(array))
array = [np.array(a) / length for a, length in zip(array, length_array)]
return np.asarray(array)


def count_n_seqs(filename: str) -> int:
"""Count number of sequences in a file.
Parameters
----------
filename : str
/path/to/sequence/file.fasta
Other text formatted sequence files (.faa, etc.) also work.
Returns
-------
int
Number of sequences contained within the input file.
"""
return len([1 for line in open(filename) if line.startswith(">")])


def check_n_seqs(filename: str, k: int, show_warning: bool = True) -> bool:
"""Check that a file contains at least k sequences.
Parameters
----------
filename : str
/path/to/sequence/file.fasta
Other text formatted sequence files (.faa, etc.) also work.
k : int
Minimum threshold.
show_warning : bool, optional
When True, if len(file) < k, a warning is displayed;
by default True.
Returns
-------
bool
True if len(file) < k; False otherwise.
"""
n_seqs = len([1 for line in open(filename) if line.startswith(">")])
if (n_seqs < k) and (show_warning):
print(
f"\nWARNING: {filename} contains an insufficient"
" number of sequences for model cross-validation"
" and will thus be excluded from Snekmer modeling."
f" ({k} folds specified in config; {n_seqs}"
" sequence(s) detected.)\n"
)
return n_seqs >= k

0 comments on commit 619f554

Please sign in to comment.