diff --git a/crates/network-scheduler/Cargo.toml b/crates/network-scheduler/Cargo.toml index 8e54274..c17e4a9 100644 --- a/crates/network-scheduler/Cargo.toml +++ b/crates/network-scheduler/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "network-scheduler" -version = "1.0.27" +version = "1.1.0" edition = "2021" [dependencies] @@ -10,7 +10,11 @@ aws-config = { version = "1", features = ["behavior-version-latest"] } aws-sdk-s3 = "1" axum = { version = "0.7", features = ["json"] } base64 = "0.22.1" +bs58 = "0.5.1" +chrono = "0.4.38" clap = { version = "4", features = ["derive", "env"] } +crypto_box = "0.9.1" +curve25519-dalek = "4.1.3" dashmap = { version = "6", features = ["serde"] } derive-enum-from-into = "0.1" env_logger = "0.11" @@ -33,7 +37,7 @@ semver = "1" serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" serde-partial = "0.3" -serde_with = { version = "3", features = ["hex"] } +serde_with = { version = "3", features = ["base64", "hex"] } serde_yaml = "0.9" sha2 = "0.10.8" sha3 = "0.10" @@ -43,7 +47,7 @@ url = "2.5.0" sqd-contract-client = { workspace = true } sqd-messages = { workspace = true, features = ["semver"] } sqd-network-transport = { workspace = true, features = ["scheduler", "metrics"] } -chrono = "0.4.38" + [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.6" diff --git a/crates/network-scheduler/src/assignment.rs b/crates/network-scheduler/src/assignment.rs index a655d4a..4c45c0f 100644 --- a/crates/network-scheduler/src/assignment.rs +++ b/crates/network-scheduler/src/assignment.rs @@ -1,6 +1,18 @@ +use core::str; use std::collections::HashMap; +use crypto_box::{ + aead::{Aead, AeadCore, OsRng}, + SalsaBox, PublicKey, SecretKey +}; use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde_with::serde_as; +use serde_with::base64::Base64; +use sha2::Sha512; +use sha2::Digest; +use sha3::digest::generic_array::GenericArray; +use curve25519_dalek::edwards::CompressedEdwardsY; use crate::signature::timed_hmac_now; @@ -23,11 +35,23 @@ pub struct Dataset { #[derive(Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] -struct EncryptedHeaders { +struct Headers { worker_id: String, worker_signature: String, } +#[serde_as] +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +struct EncryptedHeaders { + #[serde_as(as = "Base64")] + identity: Vec, + #[serde_as(as = "Base64")] + nonce: Vec, + #[serde_as(as = "Base64")] + ciphertext: Vec, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct WorkerAssignment { @@ -118,18 +142,26 @@ impl Assignment { Some(result) } - pub fn headers_for_peer_id(&self, peer_id: String) -> Option> { - let local_assignment = match self.worker_assignments.get(&peer_id) { - Some(worker_assignment) => worker_assignment, - None => { - return None - } + pub fn headers_for_peer_id(&self, peer_id: String, secret_key: Vec) -> Option> { + let local_assignment = self.worker_assignments.get(&peer_id)?; + let EncryptedHeaders {identity, nonce, ciphertext,} = &local_assignment.encrypted_headers; + let Ok(temporary_public_key) = PublicKey::from_slice(identity.as_slice()) else { + return None }; - let headers = match serde_json::to_value(&local_assignment.encrypted_headers) { - Ok(v) => v, - Err(_) => { - return None; - } + let big_slice = Sha512::default().chain_update(secret_key).finalize(); + let Ok(worker_secret_key) = SecretKey::from_slice(&big_slice[00..32]) else { + return None + }; + let shared_box = SalsaBox::new(&temporary_public_key, &worker_secret_key); + let generic_nonce = GenericArray::clone_from_slice(nonce); + let Ok(decrypted_plaintext) = shared_box.decrypt(&generic_nonce, &ciphertext[..]) else { + return None + }; + let Ok(plaintext_headers) = std::str::from_utf8(&decrypted_plaintext) else { + return None; + }; + let Ok(headers) = serde_json::from_str::(plaintext_headers) else { + return None; }; let mut result: HashMap = Default::default(); for (k,v) in headers.as_object().unwrap() { @@ -154,15 +186,58 @@ impl Assignment { } pub fn regenerate_headers(&mut self, cloudflare_storage_secret: String) { + let temporary_secret_key = SecretKey::generate(&mut OsRng); + let temporary_public_key_bytes = *temporary_secret_key.public_key().as_bytes(); + for (worker_id, worker_assignment) in &mut self.worker_assignments { let worker_signature = timed_hmac_now( worker_id, &cloudflare_storage_secret, ); - worker_assignment.encrypted_headers = EncryptedHeaders { + + let headers = Headers { worker_id: worker_id.to_string(), worker_signature, - } + }; + + let pub_key_edvards_bytes = &bs58::decode(worker_id).into_vec().unwrap()[6..]; + let public_edvards_compressed = CompressedEdwardsY::from_slice(pub_key_edvards_bytes).unwrap(); + let public_edvards = public_edvards_compressed.decompress().unwrap(); + let public_montgomery = public_edvards.to_montgomery(); + let worker_public_key = PublicKey::from(public_montgomery); + + let shared_box = SalsaBox::new(&worker_public_key, &temporary_secret_key); + let nonce = SalsaBox::generate_nonce(&mut OsRng); + let plaintext = serde_json::to_vec(&headers).unwrap(); + let ciphertext = shared_box.encrypt(&nonce, &plaintext[..]).unwrap(); + + + worker_assignment.encrypted_headers = EncryptedHeaders { + identity: temporary_public_key_bytes.to_vec(), + nonce: nonce.to_vec(), + ciphertext, + }; } } +} + +#[cfg(test)] +mod tests { + use sqd_network_transport::Keypair; + + use super::*; + + #[test] + fn it_works() { + let mut assignment: Assignment = Default::default(); + let keypair = Keypair::generate_ed25519(); + let peer_id = keypair.public().to_peer_id().to_base58(); + let private_key = keypair.try_into_ed25519().unwrap().secret(); + + assignment.insert_assignment(peer_id.clone(), "Ok".to_owned(), Default::default()); + assignment.regenerate_headers("SUPERSECRET".to_owned()); + let headers = assignment.headers_for_peer_id(peer_id.clone(), private_key.as_ref().to_vec()).unwrap(); + let decrypted_id = headers.get("worker-id").unwrap(); + assert_eq!(peer_id, decrypted_id.to_owned()); + } } \ No newline at end of file diff --git a/crates/network-scheduler/src/server.rs b/crates/network-scheduler/src/server.rs index ebbe039..4b42011 100644 --- a/crates/network-scheduler/src/server.rs +++ b/crates/network-scheduler/src/server.rs @@ -12,7 +12,6 @@ use tokio::join; use tokio::signal::unix::{signal, SignalKind}; use tokio::sync::mpsc::Receiver; use tokio::time::Instant; -use base64::{engine::general_purpose::STANDARD as base64, Engine}; use sqd_messages::signatures::msg_hash; use sqd_messages::{Pong, RangeSet}; @@ -337,7 +336,7 @@ fn build_assignment( files.insert(filename.clone(), filename); } let dataset_str = chunk.dataset_id; - let dataset_id = base64.encode(dataset_str); + let dataset_id = dataset_str; let size_bytes = chunk.size_bytes; let chunk = Chunk { id: chunk_str.clone(),