add batch logic to urlsketch
bluegenes committed Oct 1, 2024
1 parent fef279d commit 367b38d
Showing 2 changed files with 140 additions and 10 deletions.
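
For orientation: this commit ports gbsketch's batched zip output to urlsketch. When a batch size is set, signatures are written to numbered zip files rather than the single requested output (the tests below run with `-o simple.zip --batch-size 1` and assert on `simple.1.zip`, `simple.2.zip`, `simple.3.zip`), and a restarted run scans for existing batches so already-built signatures are skipped. A minimal sketch of the batch-name derivation this implies; `batch_output_path` is a hypothetical helper for illustration, not the plugin's actual function:

use std::path::{Path, PathBuf};

// Hypothetical helper: derive `<stem>.<index>.zip` from the requested
// output path, matching the filenames asserted in the tests below.
fn batch_output_path(output: &Path, batch_index: usize) -> PathBuf {
    let stem = output.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
    output.with_file_name(format!("{stem}.{batch_index}.zip"))
}

fn main() {
    let out = PathBuf::from("simple.zip");
    assert_eq!(batch_output_path(&out, 1), PathBuf::from("simple.1.zip"));
    assert_eq!(batch_output_path(&out, 2), PathBuf::from("simple.2.zip"));
}
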
49 changes: 39 additions & 10 deletions src/directsketch.rs
@@ -825,7 +825,6 @@ pub async fn gbsketch(
batch_size: u32,
output_sigs: Option<String>,
) -> Result<(), anyhow::Error> {
// if sig output provided but doesn't end in zip, bail
let batch_size = batch_size as usize;
let mut batch_index = 1;
let mut name_params_map: HashMap<String, HashSet<u64>> = HashMap::new();
@@ -1060,16 +1059,38 @@ pub async fn urlsketch(
output_sigs: Option<String>,
failed_checksums_csv: Option<String>,
) -> Result<(), anyhow::Error> {
// if sig output provided but doesn't end in zip, bail
let batch_size = batch_size as usize;
let mut batch_index = 1;
let mut name_params_map: HashMap<String, HashSet<u64>> = HashMap::new();
let mut filter = false;
if let Some(ref output_sigs) = output_sigs {
if Path::new(&output_sigs)
.extension()
.map_or(true, |ext| ext != "zip")
{
// Create outpath from output_sigs
let outpath = PathBuf::from(output_sigs);

// Check if the extension is "zip"
if outpath.extension().map_or(true, |ext| ext != "zip") {
bail!("Output must be a zip file.");
}
// find and read any existing sigs
let (existing_batches, max_existing_batch_index) =
find_existing_zip_batches(&outpath).await?;
// Check if there are any existing batches to process
if !existing_batches.is_empty() {
let existing_sigs = MultiCollection::from_zipfiles(&existing_batches)?;
name_params_map = existing_sigs.build_params_hashmap();

batch_index = max_existing_batch_index + 1;
eprintln!(
"Found {} existing zip batches. Starting new sig writing at batch {}",
max_existing_batch_index, batch_index
);
filter = true;
} else {
// No existing batches, skipping signature filtering
eprintln!("No existing signature batches found; building all signatures.");
}
}

// set up fasta download path
let download_path = PathBuf::from(fasta_location);
if !download_path.exists() {
@@ -1087,7 +1108,6 @@
// Set up collector/writing tasks
let mut handles = Vec::new();

let batch_index = 1;
let sig_handle = zipwriter_handle(
recv_sigs,
output_sigs,
Expand Down Expand Up @@ -1168,6 +1188,18 @@ pub async fn urlsketch(

for (i, accinfo) in accession_info.into_iter().enumerate() {
py.check_signals()?; // If interrupted, return an Err automatically
let mut dna_sigs = dna_template_collection.clone();
let mut prot_sigs = prot_template_collection.clone();

// filter template sigs based on existing sigs
if filter {
if let Some(existing_paramset) = name_params_map.get(&accinfo.name) {
// If the key exists, filter template sigs
dna_sigs.filter(&existing_paramset);
prot_sigs.filter(&existing_paramset);
}
}

let semaphore_clone = Arc::clone(&semaphore);
let client_clone = Arc::clone(&client);
let send_sigs = send_sigs.clone();
@@ -1176,9 +1208,6 @@
let download_path_clone = download_path.clone(); // Clone the path for each task
let send_errors = error_sender.clone();

let mut dna_sigs = dna_template_collection.clone();
let mut prot_sigs = prot_template_collection.clone();

tokio::spawn(async move {
let _permit = semaphore_clone.acquire().await;
// progress report when the permit is available and processing begins
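
`find_existing_zip_batches` is called in the restart path above but not defined in this diff. A hedged sketch of what such a scan could look like, assuming the `<stem>.<N>.zip` naming convention (synchronous and simplified; the real helper is async and may differ):

use std::path::{Path, PathBuf};

// Hedged sketch only: scan for sibling files named `<stem>.<N>.zip`
// next to the requested output path, returning their paths plus the
// highest batch index found.
fn scan_existing_batches(outpath: &Path) -> (Vec<PathBuf>, usize) {
    let stem = outpath.file_stem().and_then(|s| s.to_str()).unwrap_or("");
    let dir = outpath
        .parent()
        .filter(|p| !p.as_os_str().is_empty())
        .unwrap_or(Path::new("."));
    let mut batches = Vec::new();
    let mut max_index = 0;
    if let Ok(entries) = std::fs::read_dir(dir) {
        for entry in entries.flatten() {
            let name = entry.file_name();
            let name = name.to_string_lossy();
            // match `<stem>.<N>.zip` and parse N
            if let Some(n) = name
                .strip_prefix(stem)
                .and_then(|r| r.strip_prefix('.'))
                .and_then(|r| r.strip_suffix(".zip"))
                .and_then(|r| r.parse::<usize>().ok())
            {
                max_index = max_index.max(n);
                batches.push(entry.path());
            }
        }
    }
    (batches, max_index)
}
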
101 changes: 101 additions & 0 deletions tests/test_urlsketch.py
@@ -481,3 +481,104 @@ def test_urlsketch_md5sum_mismatch_no_checksum_file(runtmp, capfd):
assert md5sum == "b1234567"
assert download_filename == "GCA_000175535.1_genomic.urlsketch.fna.gz"
assert url == "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/175/535/GCA_000175535.1_ASM17553v1/GCA_000175535.1_ASM17553v1_genomic.fna.gz"


def test_urlsketch_simple_batched(runtmp, capfd):
acc_csv = get_test_data('acc-url.csv')
output = runtmp.output('simple.zip')
failed = runtmp.output('failed.csv')
ch_fail = runtmp.output('checksum_dl_failed.csv')

out1 = runtmp.output('simple.1.zip')
out2 = runtmp.output('simple.2.zip')
out3 = runtmp.output('simple.3.zip')

sig1 = get_test_data('GCA_000175535.1.sig.gz')
sig2 = get_test_data('GCA_000961135.2.sig.gz')
sig3 = get_test_data('GCA_000961135.2.protein.sig.gz')
ss1 = sourmash.load_one_signature(sig1, ksize=31)
ss2 = sourmash.load_one_signature(sig2, ksize=31)
ss3 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein')

runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output,
'--failed', failed, '-r', '1', '--checksum-fail', ch_fail,
'--param-str', "dna,k=31,scaled=1000", '-p', "protein,k=10,scaled=200",
'--batch-size', '1')

assert os.path.exists(out1)
assert os.path.exists(out2)
assert os.path.exists(out3)
assert not os.path.exists(output) # for now, the original (unbatched) output file should not exist.
captured = capfd.readouterr()
print(captured.err)

expected_siginfo = {
(ss1.name, ss1.md5sum(), ss1.minhash.moltype),
(ss2.name, ss2.md5sum(), ss2.minhash.moltype),
(ss3.name, ss3.md5sum(), ss3.minhash.moltype)
}
# Collect all signatures from the output zip files
all_sigs = []

for out_file in [out1, out2, out3]:
idx = sourmash.load_file_as_index(out_file)
sigs = list(idx.signatures())
assert len(sigs) == 1 # We expect exactly 1 signature per batch
all_sigs.append(sigs[0])

loaded_signatures = {(sig.name, sig.md5sum(), sig.minhash.moltype) for sig in all_sigs}
assert loaded_signatures == expected_siginfo, f"Loaded sigs: {loaded_signatures}, expected: {expected_siginfo}"


def test_urlsketch_simple_batch_restart(runtmp, capfd):
acc_csv = get_test_data('acc-url.csv')
output = runtmp.output('simple.zip')
failed = runtmp.output('failed.csv')
ch_fail = runtmp.output('checksum_dl_failed.csv')

out1 = runtmp.output('simple.1.zip')
out2 = runtmp.output('simple.2.zip')
out3 = runtmp.output('simple.3.zip')


sig1 = get_test_data('GCA_000175535.1.sig.gz')
sig2 = get_test_data('GCA_000961135.2.sig.gz')
sig3 = get_test_data('GCA_000961135.2.protein.sig.gz')
ss1 = sourmash.load_one_signature(sig1, ksize=31)
ss2 = sourmash.load_one_signature(sig2, ksize=31)
ss3 = sourmash.load_one_signature(sig2, ksize=21)
ss4 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein')

# first, cat sig2 into an output file that will trick urlsketch into thinking it's a prior batch
runtmp.sourmash('sig', 'cat', sig2, '-o', out1)
assert os.path.exists(out1)

# now run. This should mean that out1 is sig2, out2 is sig1, out3 is sig3
runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output,
'--failed', failed, '-r', '1', '--checksum-fail', ch_fail,
'--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200",
'--batch-size', '1')

assert os.path.exists(out1)
assert os.path.exists(out2)
assert os.path.exists(out3)
assert not os.path.exists(output) # for now, the original (unbatched) output file should not exist.
captured = capfd.readouterr()
print(captured.err)

expected_siginfo = {
(ss2.name, ss2.md5sum(), ss2.minhash.moltype),
(ss2.name, ss3.md5sum(), ss3.minhash.moltype), # ss2 name b/c that's how it is in acc-url.csv
(ss4.name, ss4.md5sum(), ss4.minhash.moltype),
(ss1.name, ss1.md5sum(), ss1.minhash.moltype),
}

all_siginfo = set()
for out_file in [out1, out2, out3]:
idx = sourmash.load_file_as_index(out_file)
sigs = list(idx.signatures())
for sig in sigs:
all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype))

# Verify that the loaded signatures match the expected signatures, order-independent
assert all_siginfo == expected_siginfo, f"Loaded sigs: {all_siginfo}, expected: {expected_siginfo}"
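
The `build_params_hashmap` / `filter` pair used in the restart path isn't shown in this diff either. A simplified sketch of the idea, with hypothetical stand-in types: map each signature name to the set of parameter hashes already built in prior batches, then drop matching templates so only the missing sketches are recomputed:

use std::collections::{HashMap, HashSet};

// Hypothetical stand-in for a template collection: each entry pairs a
// hash of its sketching parameters (ksize, scaled, moltype, ...) with
// a label, mirroring the `filter` calls in the diff above.
struct Templates {
    sigs: Vec<(u64, String)>,
}

impl Templates {
    // Keep only templates whose params were NOT already built.
    fn filter(&mut self, existing: &HashSet<u64>) {
        self.sigs.retain(|(params_hash, _)| !existing.contains(params_hash));
    }
}

fn main() {
    // name -> already-built param hashes, as would be assembled from
    // the existing zip batches.
    let mut name_params_map: HashMap<String, HashSet<u64>> = HashMap::new();
    name_params_map.insert("GCA_000961135.2".into(), HashSet::from([31, 21]));

    let mut templates = Templates {
        sigs: vec![(31, "dna k=31".into()), (10, "protein k=10".into())],
    };
    if let Some(existing) = name_params_map.get("GCA_000961135.2") {
        templates.filter(existing);
    }
    // only the protein template remains to be sketched
    assert_eq!(templates.sigs.len(), 1);
}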
