diff --git a/src/directsketch.rs b/src/directsketch.rs index 2fdb220..0d0e298 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -825,7 +825,6 @@ pub async fn gbsketch( batch_size: u32, output_sigs: Option, ) -> Result<(), anyhow::Error> { - // if sig output provided but doesn't end in zip, bail let batch_size = batch_size as usize; let mut batch_index = 1; let mut name_params_map: HashMap> = HashMap::new(); @@ -1060,16 +1059,38 @@ pub async fn urlsketch( output_sigs: Option, failed_checksums_csv: Option, ) -> Result<(), anyhow::Error> { - // if sig output provided but doesn't end in zip, bail let batch_size = batch_size as usize; + let mut batch_index = 1; + let mut name_params_map: HashMap> = HashMap::new(); + let mut filter = false; if let Some(ref output_sigs) = output_sigs { - if Path::new(&output_sigs) - .extension() - .map_or(true, |ext| ext != "zip") - { + // Create outpath from output_sigs + let outpath = PathBuf::from(output_sigs); + + // Check if the extension is "zip" + if outpath.extension().map_or(true, |ext| ext != "zip") { bail!("Output must be a zip file."); } + // find and read any existing sigs + let (existing_batches, max_existing_batch_index) = + find_existing_zip_batches(&outpath).await?; + // Check if there are any existing batches to process + if !existing_batches.is_empty() { + let existing_sigs = MultiCollection::from_zipfiles(&existing_batches)?; + name_params_map = existing_sigs.build_params_hashmap(); + + batch_index = max_existing_batch_index + 1; + eprintln!( + "Found {} existing zip batches. 
Starting new sig writing at batch {}", + max_existing_batch_index, batch_index + ); + filter = true; + } else { + // No existing batches, skipping signature filtering + eprintln!("No existing signature batches found; building all signatures."); + } } + // set up fasta download path let download_path = PathBuf::from(fasta_location); if !download_path.exists() { @@ -1087,7 +1108,6 @@ pub async fn urlsketch( // Set up collector/writing tasks let mut handles = Vec::new(); - let batch_index = 1; let sig_handle = zipwriter_handle( recv_sigs, output_sigs, @@ -1168,6 +1188,18 @@ pub async fn urlsketch( for (i, accinfo) in accession_info.into_iter().enumerate() { py.check_signals()?; // If interrupted, return an Err automatically + let mut dna_sigs = dna_template_collection.clone(); + let mut prot_sigs = prot_template_collection.clone(); + + // filter template sigs based on existing sigs + if filter { + if let Some(existing_paramset) = name_params_map.get(&accinfo.name) { + // If the key exists, filter template sigs + dna_sigs.filter(&existing_paramset); + prot_sigs.filter(&existing_paramset); + } + } + let semaphore_clone = Arc::clone(&semaphore); let client_clone = Arc::clone(&client); let send_sigs = send_sigs.clone(); @@ -1176,9 +1208,6 @@ pub async fn urlsketch( let download_path_clone = download_path.clone(); // Clone the path for each task let send_errors = error_sender.clone(); - let mut dna_sigs = dna_template_collection.clone(); - let mut prot_sigs = prot_template_collection.clone(); - tokio::spawn(async move { let _permit = semaphore_clone.acquire().await; // progress report when the permit is available and processing begins diff --git a/tests/test_urlsketch.py b/tests/test_urlsketch.py index 51160d2..f5de3ff 100644 --- a/tests/test_urlsketch.py +++ b/tests/test_urlsketch.py @@ -481,3 +481,104 @@ def test_urlsketch_md5sum_mismatch_no_checksum_file(runtmp, capfd): assert md5sum == "b1234567" assert download_filename == "GCA_000175535.1_genomic.urlsketch.fna.gz" 
assert url == "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/175/535/GCA_000175535.1_ASM17553v1/GCA_000175535.1_ASM17553v1_genomic.fna.gz" + + +def test_urlsketch_simple_batched(runtmp, capfd): + acc_csv = get_test_data('acc-url.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + out3 = runtmp.output('simple.3.zip') + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss3 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert os.path.exists(out2) + assert os.path.exists(out3) + assert not os.path.exists(output) # for now, orig output file should be empty. 
+ captured = capfd.readouterr() + print(captured.err) + + expected_siginfo = { + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + (ss2.name, ss2.md5sum(), ss2.minhash.moltype), + (ss3.name, ss3.md5sum(), ss3.minhash.moltype) + } + # Collect all signatures from the output zip files + all_sigs = [] + + for out_file in [out1, out2, out3]: + idx = sourmash.load_file_as_index(out_file) + sigs = list(idx.signatures()) + assert len(sigs) == 1 # We expect exactly 1 signature per batch + all_sigs.append(sigs[0]) + + loaded_signatures = {(sig.name, sig.md5sum(), sig.minhash.moltype) for sig in all_sigs} + assert loaded_signatures == expected_siginfo, f"Loaded sigs: {loaded_signatures}, expected: {expected_siginfo}" + + +def test_urlsketch_simple_batch_restart(runtmp, capfd): + acc_csv = get_test_data('acc-url.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + out3 = runtmp.output('simple.3.zip') + + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss3 = sourmash.load_one_signature(sig2, ksize=21) + ss4 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + # first, cat sig2 into an output file that will trick urlsketch into thinking it's a prior batch + runtmp.sourmash('sig', 'cat', sig2, '-o', out1) + assert os.path.exists(out1) + + # now run. 
This should mean that out1 is sig2, out2 is sig1, out3 is sig3 + runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert os.path.exists(out2) + assert os.path.exists(out3) + assert not os.path.exists(output) # for now, orig output file should not exist. + captured = capfd.readouterr() + print(captured.err) + + expected_siginfo = { + (ss2.name, ss2.md5sum(), ss2.minhash.moltype), + (ss2.name, ss3.md5sum(), ss3.minhash.moltype), # ss2 name b/c that's how it is in acc-url.csv + (ss4.name, ss4.md5sum(), ss4.minhash.moltype), + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + } + + all_siginfo = set() + for out_file in [out1, out2, out3]: + idx = sourmash.load_file_as_index(out_file) + sigs = list(idx.signatures()) + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + + # Verify that the loaded signatures match the expected signatures, order-independent + assert all_siginfo == expected_siginfo, f"Loaded sigs: {all_siginfo}, expected: {expected_siginfo}"