diff --git a/src/directsketch.rs b/src/directsketch.rs index 08539ff..2fd4862 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -7,7 +7,6 @@ use sourmash::collection::Collection; use std::cmp::max; use std::collections::HashMap; use std::fs::{self, create_dir_all}; -use std::panic; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::fs::File; @@ -388,7 +387,7 @@ async fn dl_sketch_url( accinfo: AccessionData, location: &PathBuf, retry: Option, - _keep_fastas: bool, + keep_fastas: bool, mut sigs: BuildCollection, _genomes_only: bool, _proteomes_only: bool, @@ -401,64 +400,86 @@ async fn dl_sketch_url( let name = accinfo.name; let accession = accinfo.accession; - let url = accinfo.url; - let expected_md5 = accinfo.expected_md5sum; let download_filename = accinfo.download_filename; + let filename = download_filename.clone().unwrap_or("".to_string()); let moltype = accinfo.moltype; - match download_with_retry(client, &url, expected_md5.as_deref(), retry_count).await { - Ok(data) => { - // check keep_fastas instead?? - if let Some(ref download_filename) = download_filename { - let path = location.join(download_filename); - fs::write(path, &data).context("Failed to write data to file")?; - } - if !download_only { - let filename = download_filename.clone().unwrap_or("".to_string()); - // sketch data - - match moltype { - InputMolType::Dna => { - sigs.build_sigs_from_data(data, "DNA", name.clone(), filename.clone())?; - } - InputMolType::Protein => { - sigs.build_sigs_from_data(data, "protein", name.clone(), filename.clone())?; + for (url, expected_md5) in accinfo.url_md5_pairs { + match download_with_retry(client, &url, expected_md5.as_deref(), retry_count).await { + Ok(data) => { + // if keep_fastas, write file to disk + if keep_fastas { + // note, if multiple urls are provided, this will append to the same file + if let Some(ref download_filename) = download_filename { + let path = location.join(download_filename); + // Open the file in append mode (or create it if it doesn't exist) + let mut file = tokio::fs::OpenOptions::new() + .create(true) // Create the file if it doesn't exist + .append(true) // Append to the file + .open(&path) + .await + .context("Failed to open file in append mode")?; + file.write_all(&data) + .await + .context("Failed to write data to file")?; + // std::fs::write(&path, &data).context("Failed to write data to file")?; } - }; + } + if !download_only { + // sketch data + + match moltype { + InputMolType::Dna => { + sigs.build_sigs_from_data(data, "DNA", name.clone(), filename.clone())?; + } + InputMolType::Protein => { + sigs.build_sigs_from_data( + data, + "protein", + name.clone(), + filename.clone(), + )?; + } + }; + } } - } - Err(err) => { - let error_message = err.to_string(); - // did we have a checksum error or a download error? - // here --> keep track of accession errors + filetype - if error_message.contains("MD5 hash does not match") { - let checksum_mismatch: FailedChecksum = FailedChecksum { - accession: accession.clone(), - name: name.clone(), - moltype: moltype.to_string(), - md5sum_url: None, - download_filename, - url: Some(url.clone()), - expected_md5sum: expected_md5.clone(), - reason: error_message.clone(), - }; - checksum_failures.push(checksum_mismatch); - sigs = empty_coll; - } else { - let failed_download = FailedDownload { - accession: accession.clone(), - name: name.clone(), - moltype: moltype.to_string(), - md5sum: expected_md5.map(|x| x.to_string()), - download_filename, - url: Some(url), - }; - download_failures.push(failed_download); - sigs = empty_coll; + Err(err) => { + let error_message = err.to_string(); + // did we have a checksum error or a download error? + // here --> keep track of accession errors + filetype + if error_message.contains("MD5 hash does not match") { + let checksum_mismatch: FailedChecksum = FailedChecksum { + accession: accession.clone(), + name: name.clone(), + moltype: moltype.to_string(), + md5sum_url: None, + download_filename: download_filename.clone(), + url: Some(url.clone()), + expected_md5sum: expected_md5.clone(), + reason: error_message.clone(), + }; + checksum_failures.push(checksum_mismatch); + } else { + let failed_download = FailedDownload { + accession: accession.clone(), + name: name.clone(), + moltype: moltype.to_string(), + md5sum: expected_md5.map(|x| x.to_string()), + download_filename, + url: Some(url), + }; + download_failures.push(failed_download); + // Clear signatures and return immediately on failure + sigs = empty_coll; + return Ok((sigs, download_failures, checksum_failures)); + } } } } + // Update signature info + sigs.update_info(name, filename); + Ok((sigs, download_failures, checksum_failures)) } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1e774a4..7b1116e 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -85,15 +85,15 @@ impl GenBankFileType { } } } + #[allow(dead_code)] #[derive(Clone)] pub struct AccessionData { pub accession: String, pub name: String, pub moltype: InputMolType, - pub url: reqwest::Url, - pub expected_md5sum: Option, - pub download_filename: Option, // need to require this if --keep-fastas are used + pub url_md5_pairs: Vec<(reqwest::Url, Option)>, + pub download_filename: Option, // Need to require this if --keep-fastas are used } #[derive(Clone)] @@ -227,36 +227,62 @@ pub fn load_accession_info( .ok_or_else(|| anyhow!("Missing 'moltype' field"))? .parse::() .map_err(|_| anyhow!("Invalid 'moltype' value"))?; - let expected_md5sum = record.get(3).map(|s| s.to_string()); + + // Parse URLs + let url_field = record + .get(5) + .ok_or_else(|| anyhow!("Missing 'url' field"))?; + + let urls: Vec = url_field + .split(';') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .filter_map(|s| reqwest::Url::parse(s).ok()) + .collect(); + + if urls.is_empty() { + return Err(anyhow!("No valid URLs found in 'url' field")); + } + + // Parse MD5 sums and build url_md5_pairs + let md5sum_field = record.get(3).unwrap_or(""); + let url_md5_pairs: Vec<(reqwest::Url, Option)> = { + if !md5sum_field.trim().is_empty() { + let parsed_md5s: Vec = md5sum_field + .split(';') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + if parsed_md5s.len() != urls.len() { + return Err(anyhow!( + "Number of MD5 sums ({}) does not match the number of URLs ({}) for accession '{}'", + parsed_md5s.len(), + urls.len(), + acc + )); + } + + // Pair URLs with corresponding MD5 sums + urls.into_iter() + .zip(parsed_md5s.into_iter().map(Some)) + .collect() + } else { + // If no MD5 sums are provided, pair URLs with None + urls.into_iter().zip(std::iter::repeat(None)).collect() + } + }; + let download_filename = record.get(4).map(|s| s.to_string()); if keep_fasta && download_filename.is_none() { return Err(anyhow!("Missing 'download_filename' field")); } - let url = record - .get(5) - .ok_or_else(|| anyhow!("Missing 'url' field"))? - .split(',') - .filter_map(|s| { - if s.starts_with("http://") || s.starts_with("https://") || s.starts_with("ftp://") - { - reqwest::Url::parse(s).ok() - } else { - None - } - }) - .next() - .ok_or_else(|| anyhow!("Invalid 'url' value"))?; - // count entries with url and md5sum - if expected_md5sum.is_some() { - md5sum_count += 1; - } // store accession data results.push(AccessionData { accession: acc, name, moltype, - url, - expected_md5sum, + url_md5_pairs, download_filename, }); } diff --git a/tests/sourmash_tst_utils.py b/tests/sourmash_tst_utils.py index 4bbc87f..aaaa37d 100644 --- a/tests/sourmash_tst_utils.py +++ b/tests/sourmash_tst_utils.py @@ -7,8 +7,7 @@ import collections import pprint -import pkg_resources -from pkg_resources import Requirement, resource_filename, ResolutionError +import importlib.metadata import traceback from io import open # pylint: disable=redefined-builtin from io import StringIO @@ -43,23 +42,13 @@ def _runscript(scriptname): namespace = {"__name__": "__main__"} namespace['sys'] = globals()['sys'] - try: - pkg_resources.load_entry_point("sourmash", 'console_scripts', scriptname)() - return 0 - except pkg_resources.ResolutionError: - pass - - path = scriptpath() - - scriptfile = os.path.join(path, scriptname) - if os.path.isfile(scriptfile): - if os.path.isfile(scriptfile): - exec( # pylint: disable=exec-used - compile(open(scriptfile).read(), scriptfile, 'exec'), - namespace) - return 0 - - return -1 + entry_points = importlib.metadata.entry_points( + group="console_scripts", name="sourmash" + ) + assert len(entry_points) == 1 + smash_cli = tuple(entry_points)[0].load() + smash_cli() + return 0 ScriptResults = collections.namedtuple('ScriptResults', diff --git a/tests/test-data/acc-merged.csv b/tests/test-data/acc-merged.csv new file mode 100644 index 0000000..7684b91 --- /dev/null +++ b/tests/test-data/acc-merged.csv @@ -0,0 +1,3 @@ +accession,name,moltype,md5sum,download_filename,url +both,both name,dna,,both.urlsketch.fna.gz,https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/961/135/GCA_000961135.2_ASM96113v2/GCA_000961135.2_ASM96113v2_genomic.fna.gz; https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/175/535/GCA_000175535.1_ASM17553v1/GCA_000175535.1_ASM17553v1_genomic.fna.gz + diff --git a/tests/test_urlsketch.py b/tests/test_urlsketch.py index 20b1548..3b418b6 100644 --- a/tests/test_urlsketch.py +++ b/tests/test_urlsketch.py @@ -746,3 +746,37 @@ def test_urlsketch_simple_skipmer(runtmp, capfd): assert ( siginfo["molecule"] == expected["moltype"] ), f"Moltype mismatch: {siginfo['molecule']}" + + +def test_urlsketch_simple_merged(runtmp): + acc_csv = get_test_data('acc-merged.csv') + output = runtmp.output('merged.zip') + failed = runtmp.output('failed.csv') + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + merged_sig = runtmp.output("sigmerge.zip") + + # create merged signature + runtmp.sourmash("sig", "merge", "-k", "31", sig1, sig2, "--set-name", "both name", '-o', merged_sig) + msigidx = sourmash.load_file_as_index(merged_sig) + msig = list(msigidx.signatures())[0] + print(msig.name) + + runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', + '--param-str', "dna,k=31,scaled=1000") + + assert os.path.exists(output) + assert not runtmp.last_result.out # stdout should be empty + + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + + assert len(sigs) == 1 + sig = sigs[0] + assert sig.name == msig.name == "both name" + print(msig.md5sum()) + assert sig.md5sum() == msig.md5sum() + assert sig.minhash.moltype == msig.minhash.moltype == "DNA" + assert os.path.exists(failed)