From 1b19919753a0e5fab4fd0481f6931f0686bdea69 Mon Sep 17 00:00:00 2001 From: Tessa Pierce Ward Date: Mon, 21 Oct 2024 15:42:30 -0700 Subject: [PATCH] MRG: fix bug in zip paths if output provided in current dir (#121) - fixes #118 This restores use of an output zipfile that does not contain an explicit path:`output.zip` was previously failing, but not `/path/to/output.zip`, which is why all tests passed. Since we use tempdirs to run tests, not sure how to test this appropriately...? --- src/directsketch.rs | 52 +++++++++++++++++++++++++++++++++++++++-- tests/test_gbsketch.py | 53 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/src/directsketch.rs b/src/directsketch.rs index 2cb28f2..80ad356 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -489,6 +489,13 @@ async fn dl_sketch_url( Ok((built_sigs, download_failures, checksum_failures)) } +fn get_current_directory() -> Result { + let current_dir = + std::env::current_dir().context("Failed to retrieve the current working directory")?; + PathBuf::try_from(current_dir) + .map_err(|_| anyhow::anyhow!("Current directory is not valid UTF-8")) +} + // Load existing batch files into MultiCollection, skipping corrupt files async fn load_existing_zip_batches(outpath: &PathBuf) -> Result<(MultiCollection, usize)> { // Remove the .zip extension to get the base name @@ -505,15 +512,56 @@ async fn load_existing_zip_batches(outpath: &PathBuf) -> Result<(MultiCollection let mut collections = Vec::new(); let mut highest_batch = 0; // Track the highest batch number - // Read the directory containing the outpath let dir = outpath_base .parent() - .ok_or_else(|| anyhow::anyhow!("Could not get parent directory"))?; + .filter(|p| !p.as_os_str().is_empty()) // Ensure the parent is not empty + .map(|p| p.to_path_buf()) // Use the parent if it's valid + .or_else(|| get_current_directory().ok()) // Fallback to current directory if no valid parent + .ok_or_else(|| anyhow::anyhow!("Failed to determine a valid directory"))?; + + if !dir.exists() { + return Err(anyhow::anyhow!( + "Directory for output zipfile does not exist: {}", + dir + )); + } + let mut dir_entries = tokio::fs::read_dir(dir).await?; + // get absolute path for outpath + let current_dir = std::env::current_dir().context("Failed to retrieve current directory")?; + let outpath_absolute = outpath + .parent() + .filter(|parent| parent.as_std_path().exists()) + .map(|_| outpath.clone()) + .unwrap_or_else(|| { + PathBuf::from_path_buf(current_dir.join(outpath.as_std_path())) + .expect("Failed to convert to Utf8PathBuf") + }); + // Scan through all files in the directory while let Some(entry) = dir_entries.next_entry().await? { let entry_path: PathBuf = entry.path().try_into()?; + // Skip the `outpath` itself to loading, as we just overwrite this file for now (not append) + // TO DO: if we can append to the original output file, we can include this and then just add new signatures + + // get absolute path of entry for comparison with outpath_absolute + let current_dir = + std::env::current_dir().context("Failed to retrieve current directory")?; + let entry_absolute = entry_path + .parent() + .filter(|parent| parent.as_std_path().exists()) + .map(|_| entry_path.clone()) + .unwrap_or_else(|| { + PathBuf::from_path_buf(current_dir.join(entry_path.as_std_path())) + .expect("Failed to convert to Utf8PathBuf") + }); + + // For now, skip the `outpath` itself to avoid loading, since we will just overwrite it anyway. + if entry_absolute == outpath_absolute { + eprintln!("Skipping the original output file: {}", entry_absolute); + continue; + } if let Some(file_name) = entry_path.file_name() { // Check if the file matches the base zip file or any batched zip file (outpath.zip, outpath.1.zip, etc.) diff --git a/tests/test_gbsketch.py b/tests/test_gbsketch.py index 7c31b8e..d72ed04 100644 --- a/tests/test_gbsketch.py +++ b/tests/test_gbsketch.py @@ -836,3 +836,56 @@ def test_gbsketch_bad_param_str(runtmp, capfd): print(captured) assert "Failed to parse params string: Conflicting moltype settings in param string: 'DNA' and 'protein'" in captured.err + + +def test_gbsketch_overwrite(runtmp, capfd): + # test restart with complete + incomplete zipfile batches + acc_csv = get_test_data('acc.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + out1 = runtmp.output('simple.1.zip') + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + # ss3 = sourmash.load_one_signature(sig2, ksize=21) + ss4 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + # run the workflow once - write all to single output + runtmp.sourmash('scripts', 'gbsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200") + assert os.path.exists(output) + captured = capfd.readouterr() + print(captured.err) + expected_siginfo = { + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + (ss2.name, ss2.md5sum(), ss2.minhash.moltype), + (ss4.name, ss4.md5sum(), ss4.minhash.moltype), + } + + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + all_siginfo = set() + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + + assert all_siginfo == expected_siginfo + + # now, try running again - providing same output file. + runtmp.sourmash('scripts', 'gbsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200") + + # check that sigs can still be read + assert os.path.exists(output) + assert not os.path.exists(out1) + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + all_siginfo = set() + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + assert all_siginfo == expected_siginfo