diff --git a/Cargo.lock b/Cargo.lock index 5fd77dc7..00ac2952 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1069,7 +1069,7 @@ dependencies = [ "bitflags 2.4.2", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -2881,7 +2881,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2 0.5.5", "tokio", "tower-service", "tracing", @@ -3624,6 +3624,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.7.1" @@ -4275,6 +4281,7 @@ dependencies = [ "futures", "indicatif", "libftd2xx", + "md5", "nusb", "orb-build-info", "orb-security-utils", @@ -4897,7 +4904,7 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.10.5", + "itertools 0.12.1", "log", "multimap", "once_cell", @@ -4917,7 +4924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.77", @@ -4930,7 +4937,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.77", diff --git a/hil/Cargo.toml b/hil/Cargo.toml index 6b04f5b8..24156e14 100644 --- a/hil/Cargo.toml +++ b/hil/Cargo.toml @@ -26,6 +26,7 @@ nusb.workspace = true orb-build-info.path = "../build-info" orb-security-utils = { workspace = true, features = ["reqwest"] } reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } +md5 = "0.7.0" secrecy.workspace = true tempfile 
= "3" thiserror.workspace = true diff --git a/hil/src/commands/flash.rs b/hil/src/commands/flash.rs index c82f939e..ab3d2bc7 100644 --- a/hil/src/commands/flash.rs +++ b/hil/src/commands/flash.rs @@ -27,8 +27,11 @@ pub struct Flash { #[arg(long)] slow: bool, /// If this flag is given, overwites any existing files when downloading the rts. - #[arg(long)] + #[arg(long, conflicts_with = "reuse_existing")] overwrite_existing: bool, + /// If this flag is given, reuses rts if already exists on path. + #[arg(long, conflicts_with = "overwrite_existing")] + reuse_existing: bool, } impl Flash { @@ -36,13 +39,15 @@ impl Flash { let args = self; let existing_file_behavior = if args.overwrite_existing { ExistingFileBehavior::Overwrite + } else if args.reuse_existing { + ExistingFileBehavior::Reuse } else { ExistingFileBehavior::Abort }; - ensure!( - crate::boot::is_recovery_mode_detected()?, - "orb must be in recovery mode to flash. Try running `orb-hil reboot -r`" - ); + // ensure!( + // crate::boot::is_recovery_mode_detected()?, + // "orb must be in recovery mode to flash. 
Try running `orb-hil reboot -r`" + // ); let rts_path = if let Some(ref s3_url) = args.s3_url { if args.rts_path.is_some() { bail!("both rts_path and s3_url were specified - only provide one or the other"); diff --git a/hil/src/download_s3.rs b/hil/src/download_s3.rs index 9586dd90..3fc4cb01 100644 --- a/hil/src/download_s3.rs +++ b/hil/src/download_s3.rs @@ -1,11 +1,13 @@ use std::{ fs::File, - io::IsTerminal, + io::{BufReader, IsTerminal, Read}, ops::RangeInclusive, os::unix::fs::FileExt, str::FromStr, - sync::atomic::{AtomicU64, Ordering}, - sync::Arc, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, time::Duration, }; @@ -22,6 +24,9 @@ use color_eyre::{ eyre::{ensure, ContextCompat, OptionExt, WrapErr}, Result, Section, }; +// use chksum_hash_md5 as md5; +use md5::Context; + use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::timeout; @@ -37,6 +42,8 @@ pub enum ExistingFileBehavior { Overwrite, /// If a file exists, abort Abort, + /// If a file exists, reuse it + Reuse, } struct ContentRange(RangeInclusive); @@ -49,14 +56,40 @@ impl std::fmt::Display for ContentRange { } } +fn calculate_checksum(out_path: &Utf8Path, part_size: usize) -> Result { + let file = File::open(out_path) + .with_context(|| format!("Failed to open file: {}", out_path))?; + let mut reader = BufReader::new(file); + let mut part_digests = Vec::new(); + let mut buffer = vec![0u8; part_size]; + + loop { + let bytes_read = reader + .read(&mut buffer) + .with_context(|| format!("Failed to read from file: {}", out_path))?; + if bytes_read == 0 { + break; + } + let mut hasher = md5::Context::new(); + hasher.consume(&buffer[..bytes_read]); + part_digests.push(hasher.compute()); + } + + let mut final_hasher = md5::Context::new(); + for digest in &part_digests { + final_hasher.consume(digest.as_ref()); + } + let final_digest = final_hasher.compute(); + let etag = format!("{:x}-{}", final_digest, part_digests.len()); + + Ok(etag) +} + pub async fn download_url( url: &str, 
out_path: &Utf8Path, existing_file_behavior: ExistingFileBehavior, ) -> Result<()> { - if existing_file_behavior == ExistingFileBehavior::Abort { - ensure!(!out_path.exists(), "{out_path:?} already exists!"); - } let parent_dir = out_path .parent() .expect("please provide the path to a file") @@ -69,6 +102,7 @@ pub async fn download_url( let start_time = std::time::Instant::now(); let client = client().await?; + let head_resp = client .head_object() .bucket(&s3_parts.bucket) .key(&s3_parts.key) .send() .await .wrap_err("failed to make aws head_object request")?; @@ -77,6 +111,72 @@ + let path_exists = out_path.exists(); + + if path_exists && existing_file_behavior == ExistingFileBehavior::Reuse { + info!("checking integrity of {out_path}"); + + let remote_checksum = head_resp + .e_tag() + .context("ETag not found in head response")? + .replace("\"", ""); + + let content_length = head_resp + .content_length() + .context("Content-Length not found in HEAD response")?; + + let part_number_str = remote_checksum.split('-').nth(1).with_context(|| { + format!("No part number found in ETag '{}'", remote_checksum) + })?; + + let part_number = part_number_str.parse::<i64>().with_context(|| { + format!( + "Failed to parse part number '{}' from ETag", + part_number_str + ) + })?; + println!("part_number: {}", part_number); + println!("content_length: {}", content_length); + let part_size_i64 = content_length + .checked_add(part_number - 1) + .and_then(|sum| sum.checked_div(part_number)) + .context("Division by part_number resulted in overflow or was invalid")?; + + let part_size = usize::try_from(part_size_i64).with_context(|| { + format!( + "Calculated part_size {} does not fit in usize. 
Cannot proceed.", + part_size_i64 + ) + })?; + ensure!(part_number > 0, "Part number in ETag cannot be zero"); + + println!("Calculated part_size: {}", part_size); + println!("default part_size: {}", 8192 * 1024); + let existing_checksum = calculate_checksum(out_path, part_size) + .with_context(|| { + format!("Failed to calculate checksum for '{}'", out_path) + })?; + + ensure!( + remote_checksum == existing_checksum, + "locally stored checksum of {out_path} file != remote checksum. {existing_checksum} != {remote_checksum}" + ); + info!("checksums match, reusing {out_path}"); + } + + match existing_file_behavior { + ExistingFileBehavior::Abort => { + ensure!(!path_exists, "{out_path:?} already exists!"); + } + ExistingFileBehavior::Reuse => { + if path_exists { + info!("{out_path:?} already exists, reusing it"); + return Ok(()); + } + } + ExistingFileBehavior::Overwrite => {} + } + let bytes_to_download = head_resp.content_length().unwrap(); let bytes_to_download: u64 = bytes_to_download