From 2c0b38d9e2d3e9a9b8860df096f1d10c16040aa0 Mon Sep 17 00:00:00 2001 From: Phoenix Kahlo Date: Tue, 19 Nov 2024 21:39:26 -0600 Subject: [PATCH] Allow client to use NEW_TOKEN frames When a client receives a token from a NEW_TOKEN frame, it submits it to a ValidationTokenStore object for storage. When an endpoint connects to a server, it queries the ValidationTokenStore object for a token applicable to the server name, and uses it if one is retrieved. --- quinn-proto/src/config.rs | 25 +- quinn-proto/src/connection/mod.rs | 28 +- quinn-proto/src/connection/packet_builder.rs | 2 +- quinn-proto/src/endpoint.rs | 11 +- quinn-proto/src/lib.rs | 3 + quinn-proto/src/validation_token_store.rs | 257 +++++++++++++++++++ quinn/src/lib.rs | 4 +- 7 files changed, 317 insertions(+), 13 deletions(-) create mode 100644 quinn-proto/src/validation_token_store.rs diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index a15448518..486cd285a 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -20,8 +20,9 @@ use crate::{ congestion, crypto::{self, HandshakeTokenKey, HmacKey}, shared::ConnectionId, - Duration, RandomConnectionIdGenerator, TokenLog, VarInt, VarIntBoundsExceeded, - DEFAULT_SUPPORTED_VERSIONS, INITIAL_MTU, MAX_CID_SIZE, MAX_UDP_PAYLOAD, + Duration, RandomConnectionIdGenerator, TokenLog, ValidationTokenMemoryCache, + ValidationTokenStore, VarInt, VarIntBoundsExceeded, DEFAULT_SUPPORTED_VERSIONS, INITIAL_MTU, + MAX_CID_SIZE, MAX_UDP_PAYLOAD, }; /// Parameters governing the core QUIC state machine @@ -1061,6 +1062,9 @@ pub struct ClientConfig { /// Cryptographic configuration to use pub(crate) crypto: Arc, + /// Validation token store to use + pub(crate) validation_token_store: Option>, + /// Provider that populates the destination connection ID of Initial Packets pub(crate) initial_dst_cid_provider: Arc ConnectionId + Send + Sync>, @@ -1074,6 +1078,7 @@ impl ClientConfig { Self { transport: Default::default(), crypto, + validation_token_store: Some(Arc::new(ValidationTokenMemoryCache::<2>::default())), initial_dst_cid_provider: Arc::new(|| { RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid() }), @@ -1103,6 +1108,21 @@ impl ClientConfig { self } + /// Set a custom [`ValidationTokenStore`] + /// + /// Defaults to an in-memory store limited to 256 servers and 2 tokens per server. This default + /// is chosen to complement `rustls`'s default + /// [`ClientSessionStore`][rustls::client::ClientSessionStore]. + /// + /// Setting to `None` disables the use of tokens from NEW_TOKEN frames as a client. + pub fn validation_token_store( + &mut self, + validation_token_store: Option>, + ) -> &mut Self { + self.validation_token_store = validation_token_store; + self + } + /// Set the QUIC version to use pub fn version(&mut self, version: u32) -> &mut Self { self.version = version; @@ -1134,6 +1154,7 @@ impl fmt::Debug for ClientConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("ClientConfig") .field("transport", &self.transport) + // validation_token_store not debug // crypto not debug .field("version", &self.version) .finish_non_exhaustive() diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 3714866d9..8fed421f3 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -33,8 +33,8 @@ use crate::{ token::{ResetToken, Token, TokenInner}, transport_parameters::TransportParameters, Dir, Duration, EndpointConfig, Frame, Instant, Side, StreamId, Transmit, TransportError, - TransportErrorCode, VarInt, INITIAL_MTU, MAX_CID_SIZE, MAX_STREAM_COUNT, MIN_INITIAL_SIZE, - TIMER_GRANULARITY, + TransportErrorCode, ValidationTokenStore, VarInt, INITIAL_MTU, MAX_CID_SIZE, MAX_STREAM_COUNT, + MIN_INITIAL_SIZE, TIMER_GRANULARITY, }; mod ack_frequency; @@ -194,7 +194,7 @@ pub struct Connection { error: Option, /// Sent in every outgoing Initial packet. Always empty for servers and after Initial keys are /// discarded. - retry_token: Bytes, + token: Bytes, /// Identifies Data-space packet numbers to skip. Not used in earlier spaces. packet_number_filter: PacketNumberFilter, @@ -227,6 +227,9 @@ pub struct Connection { /// no outgoing application data. app_limited: bool, + validation_token_store: Option>, + server_name: Option, + streams: StreamsState, /// Surplus remote CIDs for future use on new paths rem_cids: CidQueue, @@ -258,6 +261,8 @@ impl Connection { allow_mtud: bool, rng_seed: [u8; 32], path_validated: bool, + validation_token_store: Option>, + server_name: Option, ) -> Self { let side = if server_config.is_some() { Side::Server @@ -274,6 +279,10 @@ impl Connection { client_hello: None, }); let mut rng = StdRng::from_seed(rng_seed); + let token = validation_token_store + .as_ref() + .and_then(|store| store.take(server_name.as_ref().unwrap())) + .unwrap_or_default(); let mut this = Self { endpoint_config, server_config, @@ -324,7 +333,7 @@ impl Connection { timers: TimerTable::default(), authentication_failures: 0, error: None, - retry_token: Bytes::new(), + token, #[cfg(test)] packet_number_filter: match config.deterministic_packet_numbers { false => PacketNumberFilter::new(&mut rng), @@ -346,6 +355,9 @@ impl Connection { receiving_ecn: false, total_authed_packets: 0, + validation_token_store, + server_name, + streams: StreamsState::new( side, config.max_concurrent_uni_streams, @@ -2105,7 +2117,7 @@ impl Connection { trace!("discarding {:?} keys", space_id); if space_id == SpaceId::Initial { // No longer needed - self.retry_token = Bytes::new(); + self.token = Bytes::new(); } let space = &mut self.spaces[space_id]; space.crypto = None; @@ -2424,7 +2436,7 @@ impl Connection { self.streams.retransmit_all_for_0rtt(); let token_len = packet.payload.len() - 16; - self.retry_token = packet.payload.freeze().split_to(token_len); + self.token = packet.payload.freeze().split_to(token_len); self.state = State::Handshake(state::Handshake { expected_token: Bytes::new(), rem_cid_set: false, @@ -2866,7 +2878,9 @@ impl Connection { return Err(TransportError::FRAME_ENCODING_ERROR("empty token")); } trace!("got new token"); - // TODO: Cache, or perhaps forward to user? + if let Some(store) = self.validation_token_store.as_ref() { + store.store(self.server_name.as_ref().unwrap(), token); + } } Frame::Datagram(datagram) => { if self diff --git a/quinn-proto/src/connection/packet_builder.rs b/quinn-proto/src/connection/packet_builder.rs index 868a8c7ca..71079037e 100644 --- a/quinn-proto/src/connection/packet_builder.rs +++ b/quinn-proto/src/connection/packet_builder.rs @@ -113,7 +113,7 @@ impl PacketBuilder { SpaceId::Initial => Header::Initial(InitialHeader { src_cid: conn.handshake_cid, dst_cid, - token: conn.retry_token.clone(), + token: conn.token.clone(), number, version, }), diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 2cc8576eb..9f58a3be2 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -32,7 +32,8 @@ use crate::{ token::{TokenDecodeError, TokenInner}, transport_parameters::{PreferredAddress, TransportParameters}, Duration, Instant, ResetToken, Side, SystemTime, Token, Transmit, TransportConfig, - TransportError, INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, + TransportError, ValidationTokenStore, INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, + RESET_TOKEN_SIZE, }; /// The main entry point to the library @@ -432,6 +433,8 @@ impl Endpoint { None, config.transport, true, + config.validation_token_store, + Some(server_name.into()), ); Ok((ch, conn)) } @@ -685,6 +688,8 @@ impl Endpoint { Some(server_config), transport_config, remote_address_validated, + None, + None, ); self.index.insert_initial(dst_cid, ch); @@ -851,6 +856,8 @@ impl Endpoint { server_config: Option>, transport_config: Arc, path_validated: bool, + new_token_store: Option>, + server_name: Option, ) -> Connection { let mut rng_seed = [0; 32]; self.rng.fill_bytes(&mut rng_seed); @@ -875,6 +882,8 @@ impl Endpoint { self.allow_mtud, rng_seed, path_validated, + new_token_store, + server_name, ); let mut cids_issued = 0; diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 9625f6458..8f2c6acd5 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -96,6 +96,9 @@ mod bloom_token_log; #[cfg(feature = "fastbloom")] pub use bloom_token_log::BloomTokenLog; +mod validation_token_store; +pub use validation_token_store::{ValidationTokenMemoryCache, ValidationTokenStore}; + #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; diff --git a/quinn-proto/src/validation_token_store.rs b/quinn-proto/src/validation_token_store.rs new file mode 100644 index 000000000..61f749dbb --- /dev/null +++ b/quinn-proto/src/validation_token_store.rs @@ -0,0 +1,257 @@ +//! Storing tokens sent from servers in NEW_TOKEN frames and using them in subsequent connections + +use bytes::Bytes; +use slab::Slab; +use std::{ + collections::{hash_map, HashMap}, + mem::take, + sync::{Arc, Mutex}, +}; + +/// Responsible for storing address validation tokens received from servers and retrieving them for +/// use in subsequent connections +pub trait ValidationTokenStore: Send + Sync { + /// Potentially store a token for later one-time use + /// + /// Called when a NEW_TOKEN frame is received from the server. + fn store(&self, server_name: &str, token: Bytes); + + /// Try to find and take a token that was stored with the given server name + /// + /// The same token must never be returned from `take` twice, as doing so can be used to + /// de-anonymize a client's traffic. + /// + /// Called when trying to connect to a server. It is always ok for this to return `None`. + fn take(&self, server_name: &str) -> Option; +} + +/// `ValidationTokenMemoryCache` implementation that stores up to `N` tokens per server name for up +/// to a limited number of server names, in-memory +pub struct ValidationTokenMemoryCache(Mutex>); + +impl ValidationTokenMemoryCache { + /// Construct empty + pub fn new(max_server_names: usize) -> Self { + Self(Mutex::new(State::new(max_server_names))) + } +} + +impl ValidationTokenStore for ValidationTokenMemoryCache { + fn store(&self, server_name: &str, token: Bytes) { + self.0.lock().unwrap().store(server_name, token) + } + + fn take(&self, server_name: &str) -> Option { + self.0.lock().unwrap().take(server_name) + } +} + +/// Defaults to a size limit of 256 +impl Default for ValidationTokenMemoryCache { + fn default() -> Self { + Self::new(256) + } +} + +/// Lockable inner state of `ValidationTokenMemoryCache`. +#[derive(Debug)] +struct State { + max_server_names: usize, + // linked hash table structure + lookup: HashMap, usize>, + entries: Slab>, + oldest_newest: Option<(usize, usize)>, +} + +/// Cache entry within `State`. +#[derive(Debug)] +struct CacheEntry { + server_name: Arc, + older: Option, + newer: Option, + tokens: Queue, +} + +impl State { + fn new(max_server_names: usize) -> Self { + assert!(max_server_names > 0, "size limit cannot be 0"); + Self { + max_server_names, + lookup: HashMap::new(), + entries: Slab::new(), + oldest_newest: None, + } + } + + /// Unlink an entry's neighbors from it + fn unlink( + idx: usize, + entries: &mut Slab>, + oldest_newest: &mut Option<(usize, usize)>, + ) { + if let Some(older) = entries[idx].older { + entries[older].newer = entries[idx].newer; + } else { + // unwrap safety: entries[idx] exists, therefore oldest_newest is some + *oldest_newest = entries[idx] + .newer + .map(|newer| (oldest_newest.unwrap().0, newer)); + } + if let Some(newer) = entries[idx].newer { + entries[newer].older = entries[idx].older; + } else { + // unwrap safety: oldest_newest is none iff entries[idx] was the only entry. + // if entries[idx].older is some, entries[idx] was not the only entry + // therefore oldest_newest is some. + *oldest_newest = entries[idx] + .older + .map(|older| (older, oldest_newest.unwrap().1)); + } + } + + /// Link an entry as the most recently used entry + /// + /// Assumes any pre-existing neighbors are already unlinked. + fn link( + idx: usize, + entries: &mut Slab>, + oldest_newest: &mut Option<(usize, usize)>, + ) { + entries[idx].newer = None; + entries[idx].older = oldest_newest.map(|(_, newest)| newest); + if let &mut Some((_, ref mut newest)) = oldest_newest { + *newest = idx; + } else { + *oldest_newest = Some((idx, idx)); + } + } + + fn store(&mut self, server_name: &str, token: Bytes) { + let server_name = Arc::::from(server_name); + let idx = match self.lookup.entry(server_name.clone()) { + hash_map::Entry::Occupied(hmap_entry) => { + // key already exists, add the new token to its token stack + let entry = &mut self.entries[*hmap_entry.get()]; + entry.tokens.push(token); + + // unlink the entry and set it up to be linked as the most recently used + Self::unlink( + *hmap_entry.get(), + &mut self.entries, + &mut self.oldest_newest, + ); + *hmap_entry.get() + } + hash_map::Entry::Vacant(hmap_entry) => { + // key does not yet exist, create a new one, evicting the oldest if necessary + let removed_key = if self.entries.len() >= self.max_server_names { + // unwrap safety: max_server_names is > 0, so there's at least one entry, so + // oldest_newest is some + let oldest = self.oldest_newest.unwrap().0; + Self::unlink(oldest, &mut self.entries, &mut self.oldest_newest); + Some(self.entries.remove(oldest).server_name) + } else { + None + }; + + let mut tokens = Queue::new(); + tokens.push(token); + let idx = self.entries.insert(CacheEntry { + server_name, + // we'll link these after the fact + older: None, + newer: None, + tokens, + }); + hmap_entry.insert(idx); + + // for borrowing reasons, we must defer removing the evicted hmap entry + if let Some(removed_key) = removed_key { + let removed = self.lookup.remove(&removed_key); + debug_assert!(removed.is_some()); + } + + idx + } + }; + + // link it as the newest entry + Self::link(idx, &mut self.entries, &mut self.oldest_newest); + } + + fn take(&mut self, server_name: &str) -> Option { + if let hash_map::Entry::Occupied(hmap_entry) = self.lookup.entry(server_name.into()) { + let entry = &mut self.entries[*hmap_entry.get()]; + // pop from entry's token stack + let token = entry.tokens.pop(); + if entry.tokens.len > 1 { + // re-link entry as most recently used + Self::unlink( + *hmap_entry.get(), + &mut self.entries, + &mut self.oldest_newest, + ); + Self::link( + *hmap_entry.get(), + &mut self.entries, + &mut self.oldest_newest, + ); + } else { + // token stack emptied, remove entry + Self::unlink( + *hmap_entry.get(), + &mut self.entries, + &mut self.oldest_newest, + ); + self.entries.remove(*hmap_entry.get()); + hmap_entry.remove(); + } + Some(token) + } else { + None + } + } +} + +/// In-place deque queue of up to `N` `Bytes` +#[derive(Debug)] +struct Queue { + elems: [Bytes; N], + // if len > 0, front is elems[start] + // invariant: start < N + start: usize, + // if len > 0, back is elems[(start + len - 1) % N] + len: usize, +} + +impl Queue { + /// Construct empty + fn new() -> Self { + const EMPTY_BYTES: Bytes = Bytes::new(); + Self { + elems: [EMPTY_BYTES; N], + start: 0, + len: 0, + } + } + + /// Push to back, popping from front first if already at capacity + fn push(&mut self, elem: Bytes) { + self.elems[(self.start + self.len) % N] = elem; + if self.len < N { + self.len += 1; + } else { + self.start += 1; + self.start %= N; + } + } + + /// Pop from front, panicking if empty + fn pop(&mut self) -> Bytes { + self.len = self + .len + .checked_sub(1) + .expect("ValidationTokenMemoryCache popped from empty Queue, this is a bug!"); + take(&mut self.elems[(self.start + self.len) % N]) + } +} diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 0019b0352..81ebf86cb 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -68,8 +68,8 @@ pub use proto::{ ConfigError, ConnectError, ConnectionClose, ConnectionError, ConnectionId, ConnectionIdGenerator, ConnectionStats, Dir, EcnCodepoint, EndpointConfig, FrameStats, FrameType, IdleTimeout, MtuDiscoveryConfig, PathStats, ServerConfig, Side, StreamId, TokenLog, - TokenReuseError, Transmit, TransportConfig, TransportErrorCode, UdpStats, VarInt, - VarIntBoundsExceeded, Written, + TokenReuseError, Transmit, TransportConfig, TransportErrorCode, UdpStats, + ValidationTokenMemoryCache, ValidationTokenStore, VarInt, VarIntBoundsExceeded, Written, }; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] pub use rustls;