diff --git a/src/rust/src/core/crypto.rs b/src/rust/src/core/crypto.rs index 5c2c6a10..491cb63c 100644 --- a/src/rust/src/core/crypto.rs +++ b/src/rust/src/core/crypto.rs @@ -16,12 +16,17 @@ use thiserror::Error; #[derive(Error, Debug, Eq, PartialEq)] pub enum Error { - #[error("no sender state could be found matching the provided data")] - NoMatchingSenderState, + #[error("no receiver state could be found matching the provided data")] + NoMatchingReceiverState, } const RATCHET_INFO_STRING: &[u8; 15] = b"RingRTC Ratchet"; -const MAX_SENDER_STATES_TO_RETAIN: usize = 5; +const MAX_RECEIVER_STATES_TO_RETAIN: usize = 5; +/// Maximum number of out of order frames to keep old ratchet keys for. +/// Accommodate up to 30 frames per second for 10 seconds worth of keys. +const MAX_OOO_FRAMES: u64 = 30 * 10; +/// Maximum number of out of order ratchets to keep old ratchet keys for. +const MAX_OOO_RATCHETS: u8 = 5; pub const MAC_SIZE_BYTES: usize = 16; // For some reason the linter doesn't detect this is required in the static assertions. @@ -66,24 +71,6 @@ impl SenderState { result } - fn advance_ratchet(&self, ratchet_counter_goal: RatchetCounter) -> Self { - let mut cur = self.ratchet_counter; - let mut secret = self.current_secret; - while cur != ratchet_counter_goal { - let secret_hkdf = Hkdf::::new(None, &secret); - secret_hkdf - .expand(RATCHET_INFO_STRING, &mut secret[..]) - .unwrap_or_else(|_| { - panic!( - "HKDF should work with output of length {}", - std::mem::size_of::() - ) - }); - cur = cur.wrapping_add(1); - } - SenderState::new(ratchet_counter_goal, secret) - } - fn mut_advance_ratchet(&mut self) { let secret_hkdf = Hkdf::::new(None, &self.current_secret[..]); secret_hkdf @@ -124,6 +111,94 @@ impl SenderState { } } +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +struct ReceiverState { + sender_state: SenderState, + ratchet_frame: FrameCounter, + old_secret: Secret, + old_ratchet_counter: RatchetCounter, +} + +impl ReceiverState { + fn new(ratchet_counter: RatchetCounter, secret: Secret) -> Self { + Self { + sender_state: SenderState::new(ratchet_counter, secret), + ratchet_frame: 0, + old_secret: secret, + old_ratchet_counter: ratchet_counter, + } + } + + fn try_advance_ratchet( + &self, + ratchet_counter_goal: RatchetCounter, + frame_counter: FrameCounter, + ) -> Self { + let mut cur; + let mut secret; + + if frame_counter > self.ratchet_frame { + cur = self.sender_state.ratchet_counter; + secret = self.sender_state.current_secret; + } else { + cur = self.old_ratchet_counter; + secret = self.old_secret; + } + + while cur != ratchet_counter_goal { + let secret_hkdf = Hkdf::::new(None, &secret); + secret_hkdf + .expand(RATCHET_INFO_STRING, &mut secret[..]) + .unwrap_or_else(|_| { + panic!( + "HKDF should work with output of length {}", + std::mem::size_of::() + ) + }); + cur = cur.wrapping_add(1); + } + let sender_state = SenderState::new(ratchet_counter_goal, secret); + if frame_counter.wrapping_sub(self.ratchet_frame) > MAX_OOO_FRAMES { + Self { + sender_state, + ratchet_frame: frame_counter, + old_secret: self.sender_state.current_secret, + old_ratchet_counter: self.sender_state.ratchet_counter, + } + } else { + Self { + sender_state, + ratchet_frame: frame_counter, + old_secret: self.old_secret, + old_ratchet_counter: self.old_ratchet_counter, + } + } + } + + /// Advance the old value, if needed, to limit retention of old secrets. + /// This is not done in try_advance_ratchet to avoid unnecessary work in + /// case the ratchet secret is not used. + fn limit_ooo(&mut self) { + while self + .sender_state + .ratchet_counter + .wrapping_sub(self.old_ratchet_counter) + > MAX_OOO_RATCHETS + { + let secret_hkdf = Hkdf::::new(None, &self.old_secret); + secret_hkdf + .expand(RATCHET_INFO_STRING, &mut self.old_secret[..]) + .unwrap_or_else(|_| { + panic!( + "HKDF should work with output of length {}", + std::mem::size_of::() + ) + }); + self.old_ratchet_counter = self.old_ratchet_counter.wrapping_add(1); + } + } +} + fn convert_frame_counter_to_iv(frame_counter: FrameCounter) -> Iv { const_assert!(size_of::() >= 8); let mut result = [0u8; size_of::()]; @@ -132,14 +207,14 @@ fn convert_frame_counter_to_iv(frame_counter: FrameCounter) -> Iv { } fn check_mac( - state: &SenderState, + state: &ReceiverState, frame_counter: FrameCounter, data: &[u8], associated_data: &[u8], mac: &Mac, ) -> bool { let iv = convert_frame_counter_to_iv(frame_counter); - let mut hmac = HmacSha256::new_from_slice(&state.current_hmac_key[..]) + let mut hmac = HmacSha256::new_from_slice(&state.sender_state.current_hmac_key[..]) .expect("HMAC can take key of any size"); hmac.update(&iv[..]); hmac.update(&len_as_u32_be_bytes(data)[..]); @@ -156,9 +231,9 @@ fn len_as_u32_be_bytes(slice: &[u8]) -> [u8; 4] { (slice.len() as u32).to_be_bytes() } -fn decrypt_internal(state: &SenderState, frame_counter: FrameCounter, data: &mut [u8]) { +fn decrypt_internal(state: &ReceiverState, frame_counter: FrameCounter, data: &mut [u8]) { let mut cipher = Aes256Ctr::new( - &state.current_aes_key.into(), + &state.sender_state.current_aes_key.into(), convert_frame_counter_to_iv(frame_counter)[..].into(), ); cipher.apply_keystream(data); @@ -167,7 +242,7 @@ fn decrypt_internal(state: &SenderState, frame_counter: FrameCounter, data: &mut pub struct Context { sender_state: SenderState, next_frame_counter: FrameCounter, - remote_sender_states_by_id: HashMap>, + remote_states_by_id: HashMap>, } impl Context { @@ -177,7 +252,7 @@ impl Context { Self { sender_state, next_frame_counter: 1, - remote_sender_states_by_id: HashMap::new(), + remote_states_by_id: HashMap::new(), } } @@ -223,11 +298,11 @@ impl Context { associated_data: &[u8], mac: &Mac, ) -> Result<(), Error> { - let states = self.get_mut_ref_sender_state_vec_by_id(sender_id); + let states = self.get_mut_ref_state_vec_by_id(sender_id); // try all states with matching ratchet counters first for state in states.iter() { - if state.ratchet_counter == ratchet_counter + if state.sender_state.ratchet_counter == ratchet_counter && check_mac(state, frame_counter, data, associated_data, mac) { decrypt_internal(state, frame_counter, data); @@ -237,15 +312,16 @@ impl Context { // before giving up, try more expensive repeated ratcheting of each state to match given ratchet counter for state in states.iter_mut() { - let try_state = state.advance_ratchet(ratchet_counter); + let mut try_state = state.try_advance_ratchet(ratchet_counter, frame_counter); if check_mac(&try_state, frame_counter, data, associated_data, mac) { + try_state.limit_ooo(); *state = try_state; decrypt_internal(state, frame_counter, data); return Ok(()); } } - Err(Error::NoMatchingSenderState) + Err(Error::NoMatchingReceiverState) } pub fn send_state(&self) -> (RatchetCounter, Secret) { @@ -269,9 +345,9 @@ impl Context { self.sender_state = SenderState::new(0, secret); } - /// Pushes a new SenderState onto the remote sender states map. + /// Pushes a new ReceiverState onto the remote sender states map. /// - /// A limited number of historical sender states are kept for each sender in order to handle + /// A limited number of historical receiver states are kept for each sender in order to handle /// frames delivered out of order with updated secrets. pub fn add_receive_secret( &mut self, @@ -279,17 +355,17 @@ impl Context { ratchet_counter: RatchetCounter, secret: Secret, ) { - let states = self.get_mut_ref_sender_state_vec_by_id(sender_id); - if states.len() == MAX_SENDER_STATES_TO_RETAIN { + let states = self.get_mut_ref_state_vec_by_id(sender_id); + if states.len() == MAX_RECEIVER_STATES_TO_RETAIN { states.pop(); } - states.insert(0, SenderState::new(ratchet_counter, secret)); + states.insert(0, ReceiverState::new(ratchet_counter, secret)); } - fn get_mut_ref_sender_state_vec_by_id(&mut self, sender_id: SenderId) -> &mut Vec { - self.remote_sender_states_by_id + fn get_mut_ref_state_vec_by_id(&mut self, sender_id: SenderId) -> &mut Vec { + self.remote_states_by_id .entry(sender_id) - .or_insert_with(|| Vec::with_capacity(MAX_SENDER_STATES_TO_RETAIN)) + .or_insert_with(|| Vec::with_capacity(MAX_RECEIVER_STATES_TO_RETAIN)) } } @@ -330,7 +406,7 @@ mod tests { let sender_id: SenderId = 42; ctx.add_receive_secret(sender_id, 0, send_secret); - let mut data = Vec::from(&plaintext[..]); + let mut data = plaintext.to_vec(); let associated_data = Vec::from("Can't touch this"); let mut mac = Mac::default(); let (ratchet_counter, frame_counter) = @@ -360,7 +436,7 @@ mod tests { let sender_id: SenderId = 8675309; ctx.add_receive_secret(sender_id, 0, send_secret); - let mut data = Vec::from(&plaintext[..]); + let mut data = plaintext.to_vec(); let associated_data = Vec::from("Can't touch this"); let mut mac = Mac::default(); let (ratchet_counter, frame_counter) = @@ -381,7 +457,7 @@ mod tests { let mut ctx2 = Context::new(random_secret(&mut rng)); ctx2.add_receive_secret(sender_id, ratchet_counter2, secret2); - let mut data = Vec::from(&plaintext[..]); + let mut data = plaintext.to_vec(); let associated_data = Vec::from("Can't touch this"); let mut mac = [0u8; MAC_SIZE_BYTES]; let (ratchet_counter, frame_counter) = @@ -397,7 +473,7 @@ mod tests { )?; assert_eq!(&plaintext[..], &data[..]); - let mut data = Vec::from(&plaintext[..]); + let mut data = plaintext.to_vec(); let (ratchet_counter, frame_counter) = ctx.encrypt(&mut data[..], &associated_data[..], &mut mac)?; assert_eq!(ratchet_counter2, ratchet_counter); @@ -423,7 +499,7 @@ mod tests { let sender_id: SenderId = 1392; ctx.add_receive_secret(sender_id, 0, send_secret); - let mut data = Vec::from(&plaintext[..]); + let mut data = plaintext.to_vec(); let associated_data = Vec::from("Can't touch this"); let mut mac = Mac::default(); let (ratchet_counter, frame_counter) = @@ -443,7 +519,7 @@ mod tests { let new_secret = random_secret(&mut rng); ctx.add_receive_secret(sender_id, 0, new_secret); - let mut data = Vec::from(&plaintext[..]); + let mut data = plaintext.to_vec(); let associated_data = Vec::from("Can't touch this"); let mut mac = Mac::default(); let (ratchet_counter, frame_counter) = @@ -462,7 +538,7 @@ mod tests { ctx.reset_send_ratchet(new_secret); - let mut data = Vec::from(&plaintext[..]); + let mut data = plaintext.to_vec(); let mut mac = Mac::default(); let (ratchet_counter, frame_counter) = ctx.encrypt(&mut data[..], &associated_data[..], &mut mac)?; @@ -490,7 +566,7 @@ mod tests { let sender_id: SenderId = 1492; ctx.add_receive_secret(sender_id, 0, send_secret); - let mut data = Vec::from(&plaintext[..]); + let mut data = plaintext.to_vec(); let mut associated_data = Vec::from("Can't touch this"); let mut mac = Mac::default(); let (ratchet_counter, frame_counter) = @@ -507,7 +583,7 @@ mod tests { &mac, ) .expect_err("decrypt should have returned an error"); - assert_eq!(err, Error::NoMatchingSenderState); + assert_eq!(err, Error::NoMatchingReceiverState); mac[0] = mac[0].wrapping_sub(1); ctx.decrypt( @@ -531,7 +607,7 @@ mod tests { &mac, ) .expect_err("decrypt should have returned an error"); - assert_eq!(err, Error::NoMatchingSenderState); + assert_eq!(err, Error::NoMatchingReceiverState); Ok(()) } @@ -539,12 +615,64 @@ mod tests { #[test] fn test_advance_ratchet_equal_sender_states() { let mut rng = StdRng::from_seed([0x34; 32]); - let sender_state = SenderState::new(0, random_secret(&mut rng)); + let secret = random_secret(&mut rng); + let sender_state = SenderState::new(0, secret); + let receiver_state = ReceiverState::new(0, secret); let mut sender_state_mut = sender_state; - let sender_state_adv = sender_state.advance_ratchet(5); + let receiver_state_adv = receiver_state.try_advance_ratchet(5, 0); for _ in 0..5 { sender_state_mut.mut_advance_ratchet(); } - assert_eq!(sender_state_adv, sender_state_mut); + assert_eq!(receiver_state_adv.sender_state, sender_state_mut); + } + + #[test] + fn test_ooo_ratchet() -> Result<(), Box> { + let plaintext = b"Whan Zephirus eek with his sweete breeth"; + let mut rng = StdRng::from_seed([0x2D; 32]); + let send_secret = random_secret(&mut rng); + let mut ctx = Context::new(send_secret); + let sender_id: SenderId = 8675309; + ctx.add_receive_secret(sender_id, 0, send_secret); + + let mut data1 = plaintext.to_vec(); + let associated_data1 = Vec::from("Can't touch this"); + let mut mac1 = Mac::default(); + let (ratchet_counter1, frame_counter1) = + ctx.encrypt(&mut data1[..], &associated_data1[..], &mut mac1)?; + assert_eq!(0, ratchet_counter1); + + let (ratchet_counter2, secret2) = ctx.advance_send_ratchet(); + // Another receiver that learned the secret after the ratchet was advanced + let mut ctx2 = Context::new(random_secret(&mut rng)); + ctx2.add_receive_secret(sender_id, ratchet_counter2, secret2); + + let mut data2 = plaintext.to_vec(); + let associated_data2 = Vec::from("Can't touch this"); + let mut mac2 = [0u8; MAC_SIZE_BYTES]; + let (ratchet_counter2, frame_counter2) = + ctx.encrypt(&mut data2[..], &associated_data2[..], &mut mac2)?; + assert_eq!(1, ratchet_counter2); + ctx.decrypt( + sender_id, + ratchet_counter2, + frame_counter2, + &mut data2[..], + &associated_data2[..], + &mac2, + )?; + assert_eq!(&plaintext[..], &data2[..]); + + // Now decrypt the first message, out of order + ctx.decrypt( + sender_id, + ratchet_counter1, + frame_counter1, + &mut data1[..], + &associated_data1[..], + &mac1, + )?; + assert_eq!(&plaintext[..], &data1[..]); + Ok(()) } }