Skip to content

Commit

Permalink
allow merging urls into single sketch
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed Dec 18, 2024
1 parent 927916c commit 85b2212
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 94 deletions.
123 changes: 72 additions & 51 deletions src/directsketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use sourmash::collection::Collection;
use std::cmp::max;
use std::collections::HashMap;
use std::fs::{self, create_dir_all};
use std::panic;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::fs::File;
Expand Down Expand Up @@ -388,7 +387,7 @@ async fn dl_sketch_url(
accinfo: AccessionData,
location: &PathBuf,
retry: Option<u32>,
_keep_fastas: bool,
keep_fastas: bool,
mut sigs: BuildCollection,
_genomes_only: bool,
_proteomes_only: bool,
Expand All @@ -401,64 +400,86 @@ async fn dl_sketch_url(

let name = accinfo.name;
let accession = accinfo.accession;
let url = accinfo.url;
let expected_md5 = accinfo.expected_md5sum;
let download_filename = accinfo.download_filename;
let filename = download_filename.clone().unwrap_or("".to_string());
let moltype = accinfo.moltype;

match download_with_retry(client, &url, expected_md5.as_deref(), retry_count).await {
Ok(data) => {
// check keep_fastas instead??
if let Some(ref download_filename) = download_filename {
let path = location.join(download_filename);
fs::write(path, &data).context("Failed to write data to file")?;
}
if !download_only {
let filename = download_filename.clone().unwrap_or("".to_string());
// sketch data

match moltype {
InputMolType::Dna => {
sigs.build_sigs_from_data(data, "DNA", name.clone(), filename.clone())?;
}
InputMolType::Protein => {
sigs.build_sigs_from_data(data, "protein", name.clone(), filename.clone())?;
for (url, expected_md5) in accinfo.url_md5_pairs {
match download_with_retry(client, &url, expected_md5.as_deref(), retry_count).await {
Ok(data) => {
// if keep_fastas, write file to disk
if keep_fastas {
// note, if multiple urls are provided, this will append to the same file
if let Some(ref download_filename) = download_filename {
let path = location.join(download_filename);
// Open the file in append mode (or create it if it doesn't exist)
let mut file = tokio::fs::OpenOptions::new()
.create(true) // Create the file if it doesn't exist
.append(true) // Append to the file
.open(&path)
.await
.context("Failed to open file in append mode")?;
file.write_all(&data)
.await
.context("Failed to write data to file")?;
// std::fs::write(&path, &data).context("Failed to write data to file")?;
}
};
}
if !download_only {
// sketch data

match moltype {
InputMolType::Dna => {
sigs.build_sigs_from_data(data, "DNA", name.clone(), filename.clone())?;
}
InputMolType::Protein => {
sigs.build_sigs_from_data(
data,
"protein",
name.clone(),
filename.clone(),
)?;
}
};
}
}
}
Err(err) => {
let error_message = err.to_string();
// did we have a checksum error or a download error?
// here --> keep track of accession errors + filetype
if error_message.contains("MD5 hash does not match") {
let checksum_mismatch: FailedChecksum = FailedChecksum {
accession: accession.clone(),
name: name.clone(),
moltype: moltype.to_string(),
md5sum_url: None,
download_filename,
url: Some(url.clone()),
expected_md5sum: expected_md5.clone(),
reason: error_message.clone(),
};
checksum_failures.push(checksum_mismatch);
sigs = empty_coll;
} else {
let failed_download = FailedDownload {
accession: accession.clone(),
name: name.clone(),
moltype: moltype.to_string(),
md5sum: expected_md5.map(|x| x.to_string()),
download_filename,
url: Some(url),
};
download_failures.push(failed_download);
sigs = empty_coll;
Err(err) => {
let error_message = err.to_string();
// did we have a checksum error or a download error?
// here --> keep track of accession errors + filetype
if error_message.contains("MD5 hash does not match") {
let checksum_mismatch: FailedChecksum = FailedChecksum {
accession: accession.clone(),
name: name.clone(),
moltype: moltype.to_string(),
md5sum_url: None,
download_filename: download_filename.clone(),
url: Some(url.clone()),
expected_md5sum: expected_md5.clone(),
reason: error_message.clone(),
};
checksum_failures.push(checksum_mismatch);
} else {
let failed_download = FailedDownload {
accession: accession.clone(),
name: name.clone(),
moltype: moltype.to_string(),
md5sum: expected_md5.map(|x| x.to_string()),
download_filename,
url: Some(url),
};
download_failures.push(failed_download);
// Clear signatures and return immediately on failure
sigs = empty_coll;
return Ok((sigs, download_failures, checksum_failures));
}
}
}
}

// Update signature info
sigs.update_info(name, filename);

Ok((sigs, download_failures, checksum_failures))
}

Expand Down
74 changes: 50 additions & 24 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ impl GenBankFileType {
}
}
}

#[allow(dead_code)]
#[derive(Clone)]
pub struct AccessionData {
pub accession: String,
pub name: String,
pub moltype: InputMolType,
pub url: reqwest::Url,
pub expected_md5sum: Option<String>,
pub download_filename: Option<String>, // need to require this if --keep-fastas are used
pub url_md5_pairs: Vec<(reqwest::Url, Option<String>)>,
pub download_filename: Option<String>, // Need to require this if --keep-fastas are used
}

#[derive(Clone)]
Expand Down Expand Up @@ -227,36 +227,62 @@ pub fn load_accession_info(
.ok_or_else(|| anyhow!("Missing 'moltype' field"))?
.parse::<InputMolType>()
.map_err(|_| anyhow!("Invalid 'moltype' value"))?;
let expected_md5sum = record.get(3).map(|s| s.to_string());

// Parse URLs
let url_field = record
.get(5)
.ok_or_else(|| anyhow!("Missing 'url' field"))?;

let urls: Vec<reqwest::Url> = url_field
.split(';')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.filter_map(|s| reqwest::Url::parse(s).ok())
.collect();

if urls.is_empty() {
return Err(anyhow!("No valid URLs found in 'url' field"));
}

// Parse MD5 sums and build url_md5_pairs
let md5sum_field = record.get(3).unwrap_or("");
let url_md5_pairs: Vec<(reqwest::Url, Option<String>)> = {
if !md5sum_field.trim().is_empty() {
let parsed_md5s: Vec<String> = md5sum_field
.split(';')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();

if parsed_md5s.len() != urls.len() {
return Err(anyhow!(
"Number of MD5 sums ({}) does not match the number of URLs ({}) for accession '{}'",
parsed_md5s.len(),
urls.len(),
acc
));
}

// Pair URLs with corresponding MD5 sums
urls.into_iter()
.zip(parsed_md5s.into_iter().map(Some))
.collect()
} else {
// If no MD5 sums are provided, pair URLs with None
urls.into_iter().zip(std::iter::repeat(None)).collect()
}
};

let download_filename = record.get(4).map(|s| s.to_string());
if keep_fasta && download_filename.is_none() {
return Err(anyhow!("Missing 'download_filename' field"));
}
let url = record
.get(5)
.ok_or_else(|| anyhow!("Missing 'url' field"))?
.split(',')
.filter_map(|s| {
if s.starts_with("http://") || s.starts_with("https://") || s.starts_with("ftp://")
{
reqwest::Url::parse(s).ok()
} else {
None
}
})
.next()
.ok_or_else(|| anyhow!("Invalid 'url' value"))?;
// count entries with url and md5sum
if expected_md5sum.is_some() {
md5sum_count += 1;
}
// store accession data
results.push(AccessionData {
accession: acc,
name,
moltype,
url,
expected_md5sum,
url_md5_pairs,
download_filename,
});
}
Expand Down
27 changes: 8 additions & 19 deletions tests/sourmash_tst_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import collections
import pprint

import pkg_resources
from pkg_resources import Requirement, resource_filename, ResolutionError
import importlib.metadata
import traceback
from io import open # pylint: disable=redefined-builtin
from io import StringIO
Expand Down Expand Up @@ -43,23 +42,13 @@ def _runscript(scriptname):
namespace = {"__name__": "__main__"}
namespace['sys'] = globals()['sys']

try:
pkg_resources.load_entry_point("sourmash", 'console_scripts', scriptname)()
return 0
except pkg_resources.ResolutionError:
pass

path = scriptpath()

scriptfile = os.path.join(path, scriptname)
if os.path.isfile(scriptfile):
if os.path.isfile(scriptfile):
exec( # pylint: disable=exec-used
compile(open(scriptfile).read(), scriptfile, 'exec'),
namespace)
return 0

return -1
entry_points = importlib.metadata.entry_points(
group="console_scripts", name="sourmash"
)
assert len(entry_points) == 1
smash_cli = tuple(entry_points)[0].load()
smash_cli()
return 0


ScriptResults = collections.namedtuple('ScriptResults',
Expand Down
3 changes: 3 additions & 0 deletions tests/test-data/acc-merged.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
accession,name,moltype,md5sum,download_filename,url
both,both name,dna,,both.urlsketch.fna.gz,https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/961/135/GCA_000961135.2_ASM96113v2/GCA_000961135.2_ASM96113v2_genomic.fna.gz; https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/175/535/GCA_000175535.1_ASM17553v1/GCA_000175535.1_ASM17553v1_genomic.fna.gz

34 changes: 34 additions & 0 deletions tests/test_urlsketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,3 +746,37 @@ def test_urlsketch_simple_skipmer(runtmp, capfd):
assert (
siginfo["molecule"] == expected["moltype"]
), f"Moltype mismatch: {siginfo['molecule']}"


def test_urlsketch_simple_merged(runtmp):
    # The input CSV lists two URLs for one accession; urlsketch should
    # download both and merge them into a single signature.
    acc_csv = get_test_data('acc-merged.csv')
    output = runtmp.output('merged.zip')
    failed = runtmp.output('failed.csv')

    # Build the reference signature by merging the two per-genome sigs
    # with `sourmash sig merge`, renamed to match the CSV's name column.
    ref_zip = runtmp.output("sigmerge.zip")
    runtmp.sourmash("sig", "merge", "-k", "31",
                    get_test_data('GCA_000175535.1.sig.gz'),
                    get_test_data('GCA_000961135.2.sig.gz'),
                    "--set-name", "both name", '-o', ref_zip)
    ref_sig = next(iter(sourmash.load_file_as_index(ref_zip).signatures()))
    print(ref_sig.name)

    # Run urlsketch over the multi-URL accession.
    runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output,
                    '--failed', failed, '-r', '1',
                    '--param-str', "dna,k=31,scaled=1000")

    assert os.path.exists(output)
    assert not runtmp.last_result.out  # stdout should be empty

    # Exactly one signature should come out, identical to the reference.
    produced = list(sourmash.load_file_as_index(output).signatures())
    assert len(produced) == 1
    made_sig = produced[0]
    assert made_sig.name == ref_sig.name == "both name"
    print(ref_sig.md5sum())
    assert made_sig.md5sum() == ref_sig.md5sum()
    assert made_sig.minhash.moltype == ref_sig.minhash.moltype == "DNA"
    assert os.path.exists(failed)

0 comments on commit 85b2212

Please sign in to comment.