diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 3b0030d9..91947347 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -30,6 +30,8 @@ pub struct FramedRead { max_header_list_size: usize, + max_continuation_frames: usize, + partial: Option, } @@ -41,6 +43,8 @@ struct Partial { /// Partial header payload buf: BytesMut, + + continuation_frames_count: usize, } #[derive(Debug)] @@ -51,10 +55,16 @@ enum Continuable { impl FramedRead { pub fn new(inner: InnerFramedRead) -> FramedRead { + let max_header_list_size = DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE; + let max_continuation_frames = calc_max_continuation_frames( + max_header_list_size, + inner.decoder().max_frame_length(), + ); FramedRead { inner, hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), - max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, + max_header_list_size, + max_continuation_frames, partial: None, } } @@ -68,7 +78,6 @@ impl FramedRead { } /// Returns the current max frame size setting - #[cfg(feature = "unstable")] #[inline] pub fn max_frame_size(&self) -> usize { self.inner.decoder().max_frame_length() @@ -80,13 +89,23 @@ impl FramedRead { #[inline] pub fn set_max_frame_size(&mut self, val: usize) { assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); - self.inner.decoder_mut().set_max_frame_length(val) + self.inner.decoder_mut().set_max_frame_length(val); + // Update max CONTINUATION frames too, since its based on this + self.max_continuation_frames = calc_max_continuation_frames( + self.max_header_list_size, + val, + ); } /// Update the max header list size setting. #[inline] pub fn set_max_header_list_size(&mut self, val: usize) { self.max_header_list_size = val; + // Update max CONTINUATION frames too, since its based on this + self.max_continuation_frames = calc_max_continuation_frames( + val, + self.max_frame_size(), + ); } /// Update the header table size setting. @@ -96,12 +115,22 @@ impl FramedRead { } } +fn calc_max_continuation_frames(header_max: usize, frame_max: usize) -> usize { + // At least this many frames needed to use max header list size + let min_frames_for_list = (header_max / frame_max).max(1); + // Some padding for imperfectly packed frames + // 25% without floats + let padding = min_frames_for_list >> 2; + min_frames_for_list.saturating_add(padding).max(5) +} + /// Decodes a frame. /// /// This method is intentionally de-generified and outlined because it is very large. fn decode_frame( hpack: &mut hpack::Decoder, max_header_list_size: usize, + max_continuation_frames: usize, partial_inout: &mut Option, mut bytes: BytesMut, ) -> Result, Error> { @@ -169,6 +198,7 @@ fn decode_frame( *partial_inout = Some(Partial { frame: Continuable::$frame(frame), buf: payload, + continuation_frames_count: 0, }); return Ok(None); @@ -259,6 +289,7 @@ fn decode_frame( Kind::Continuation => { let is_end_headers = (head.flag() & 0x4) == 0x4; + let mut partial = match partial_inout.take() { Some(partial) => partial, None => { @@ -273,6 +304,23 @@ fn decode_frame( return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } + // Check for CONTINUATION flood + if is_end_headers { + partial.continuation_frames_count = 0; + } else { + let cnt = partial.continuation_frames_count + 1; + if cnt > max_continuation_frames { + tracing::debug!("too_many_continuations, max = {}", max_continuation_frames); + return Err(Error::library_go_away_data( + Reason::ENHANCE_YOUR_CALM, + "too_many_continuations", + )); + } else { + partial.continuation_frames_count = cnt; + } + } + + // Extend the buf if partial.buf.is_empty() { partial.buf = bytes.split_off(frame::HEADER_LEN); @@ -354,9 +402,16 @@ where ref mut hpack, max_header_list_size, ref mut partial, + max_continuation_frames, .. } = *self; - if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? { + if let Some(frame) = decode_frame( + hpack, + max_header_list_size, + max_continuation_frames, + partial, + bytes, + )? { tracing::debug!(?frame, "received"); return Poll::Ready(Some(Ok(frame))); } diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index 7990c86a..e350bc6f 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -883,6 +883,53 @@ async fn too_big_headers_sends_reset_after_431_if_not_eos() { join(client, srv).await; } +#[tokio::test] +async fn too_many_continuation_frames_sends_goaway() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_frame_eq(settings, frames::settings().max_header_list_size(1024 * 24)); + + // the mock impl automatically splits into CONTINUATION frames if the + // headers are too big for one frame. So without a max header list size + // set, we'll send a bunch of headers that will eventually get nuked. + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .field("a".repeat(10_000), "b".repeat(10_000)) + .field("c".repeat(10_000), "d".repeat(10_000)) + .field("e".repeat(10_000), "f".repeat(10_000)) + .field("g".repeat(10_000), "h".repeat(10_000)) + .field("i".repeat(10_000), "j".repeat(10_000)) + .field("k".repeat(10_000), "l".repeat(10_000)) + .field("m".repeat(10_000), "n".repeat(10_000)) + .field("o".repeat(10_000), "p".repeat(10_000)) + .field("y".repeat(10_000), "z".repeat(10_000)), + ) + .await; + client.recv_frame(frames::go_away(0).calm().data("too_many_continuations")).await; + }; + + let srv = async move { + let mut srv = server::Builder::new() + // should mean ~3 continuation + .max_header_list_size(1024 * 24) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let err = srv.next().await.unwrap().expect_err("server"); + assert!(err.is_go_away()); + assert!(err.is_library()); + assert_eq!(err.reason(), Some(Reason::ENHANCE_YOUR_CALM)); + }; + + join(client, srv).await; +} + #[tokio::test] async fn pending_accept_recv_illegal_content_length_data() { h2_support::trace_init!();