From 1fb5e204eef9700d6070b2ccc4607ed4b8426de5 Mon Sep 17 00:00:00 2001 From: wcampbell Date: Sat, 30 Nov 2024 23:48:20 -0500 Subject: [PATCH] reader: Remove bits_read for end * Remove bits_read from end, as it's not really read, and we need to count the read when the leftover is read later. * Fix missing bits_read increase on previous read * Improve test coverage of reader by adding tests, and add more asserts --- src/reader.rs | 68 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/src/reader.rs b/src/reader.rs index 9a93967a..68ee7fbd 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -140,7 +140,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> { /// Return true if we are at the end of a reader and there are no cached bits in the reader. /// Since this uses [Read] internally, this will return true when [Read] returns [ErrorKind::UnexpectedEof]. /// - /// The byte that was read will be internally buffered + /// The byte that was read will be internally buffered and will *not* be included in the `bits_read` count. #[inline] pub fn end(&mut self) -> bool { if self.leftover.is_some() { @@ -160,13 +160,14 @@ impl<'a, R: Read + Seek> Reader<'a, R> { #[cfg(feature = "logging")] log::trace!("not end: read {:02x?}", &buf); - self.bits_read += 8; self.leftover = Some(Leftover::Byte(buf[0])); false } } /// Used at the beginning of `from_reader`. + /// + /// This will increment `bits_read`. // TODO: maybe send into read_bytes() if amt >= 8 #[inline] pub fn skip_bits(&mut self, amt: usize) -> Result<(), DekuError> { @@ -186,7 +187,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> { .map_err(|e| DekuError::Io(e.kind()))?; } - // Unlike normal seek not couting as bits_read, this one does + // Unlike normal seek not counting as bits_read, this one does // to keep from_bytes returns self.bits_read += bytes_amt * 8; @@ -374,6 +375,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> { #[cfg(feature = "logging")] log::trace!("read_bytes_const_leftover: returning {:02x?}", &buf); + self.bits_read += amt * 8; return Ok(ReaderRet::Bytes); } let buf_len = buf.len(); @@ -389,7 +391,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> { } return Err(DekuError::Io(e.kind())); } - self.bits_read += remaining * 8; + self.bits_read += amt * 8; #[cfg(feature = "logging")] log::trace!("read_bytes_leftover: returning {:02x?}", &buf); @@ -462,6 +464,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> { if remaining == 0 { #[cfg(feature = "logging")] log::trace!("read_bytes_const_leftover: returning {:02x?}", &buf); + self.bits_read += N * 8; return Ok(ReaderRet::Bytes); } @@ -478,7 +481,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> { } return Err(DekuError::Io(e.kind())); } - self.bits_read += remaining * 8; + self.bits_read += N * 8; #[cfg(feature = "logging")] log::trace!("read_bytes_const_leftover: returning {:02x?}", &buf); @@ -503,15 +506,26 @@ mod tests { let mut buf = [0; 2]; let _ = reader.read_bytes_const::<2>(&mut buf).unwrap(); assert!(reader.end()); + assert_eq!(reader.bits_read, 8 * 2); let input = hex!("aa"); let mut cursor = Cursor::new(input); let mut reader = Reader::new(&mut cursor); assert!(!reader.end()); - let _ = reader.read_bits(4); + let _ = reader.read_bits(4).unwrap(); + assert!(!reader.end()); + let _ = reader.read_bits(4).unwrap(); + assert!(reader.end()); + assert_eq!(reader.bits_read, 8); + + let input = hex!("aa"); + let mut cursor = Cursor::new(input); + let mut reader = Reader::new(&mut cursor); assert!(!reader.end()); - let _ = reader.read_bits(4); + let mut buf = [0; 1]; + let _ = reader.read_bytes(1, &mut buf).unwrap(); assert!(reader.end()); + assert_eq!(reader.bits_read, 8); } #[test] @@ -520,9 +534,22 @@ mod tests { let input = hex!("aa"); let mut cursor = Cursor::new(input); let mut reader = Reader::new(&mut cursor); - let _ = reader.read_bits(1); - let _ = reader.read_bits(4); - let _ = reader.read_bits(3); + let _ = reader.read_bits(1).unwrap(); + let _ = reader.read_bits(4).unwrap(); + let _ = reader.read_bits(3).unwrap(); + assert_eq!(reader.bits_read, 8); + } + + #[test] + #[cfg(feature = "bits")] + fn test_bits_skip() { + let input = hex!("aa"); + let mut cursor = Cursor::new(input); + let mut reader = Reader::new(&mut cursor); + let _ = reader.skip_bits(1).unwrap(); + let _ = reader.skip_bits(4).unwrap(); + let _ = reader.skip_bits(3).unwrap(); + assert_eq!(reader.bits_read, 8); } #[test] @@ -531,8 +558,9 @@ mod tests { let mut cursor = Cursor::new(input); let mut reader = Reader::new(&mut cursor); let mut buf = [0; 1]; - let _ = reader.read_bytes(1, &mut buf); + let _ = reader.read_bytes(1, &mut buf).unwrap(); assert_eq!([0xaa], buf); + assert_eq!(reader.bits_read, 8); } #[test] @@ -542,22 +570,28 @@ mod tests { let mut cursor = Cursor::new(input); let mut reader = Reader::new(&mut cursor); let mut buf = [0; 1]; - let _ = reader.read_bytes(1, &mut buf); + let _ = reader.read_bytes(1, &mut buf).unwrap(); assert_eq!([0xaa], buf); + assert_eq!(reader.bits_read, 8); + reader.seek_last_read().unwrap(); - let _ = reader.read_bytes(1, &mut buf); + let _ = reader.read_bytes(1, &mut buf).unwrap(); assert_eq!([0xaa], buf); + assert_eq!(reader.bits_read, 8); // 2 bytes (and const) let input = hex!("aabb"); let mut cursor = Cursor::new(input); let mut reader = Reader::new(&mut cursor); let mut buf = [0; 2]; - let _ = reader.read_bytes_const::<2>(&mut buf); + let _ = reader.read_bytes_const::<2>(&mut buf).unwrap(); assert_eq!([0xaa, 0xbb], buf); + assert_eq!(reader.bits_read, 16); + reader.seek_last_read().unwrap(); - let _ = reader.read_bytes_const::<2>(&mut buf); + let _ = reader.read_bytes_const::<2>(&mut buf).unwrap(); assert_eq!([0xaa, 0xbb], buf); + assert_eq!(reader.bits_read, 16); } #[cfg(feature = "bits")] @@ -568,9 +602,11 @@ mod tests { let mut reader = Reader::new(&mut cursor); let bits = reader.read_bits(4).unwrap(); assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0])); + assert_eq!(reader.bits_read, 4); reader.seek_last_read().unwrap(); let bits = reader.read_bits(4).unwrap(); assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0])); + assert_eq!(reader.bits_read, 4); // more than byte let input = hex!("abd0"); @@ -578,8 +614,10 @@ mod tests { let mut reader = Reader::new(&mut cursor); let bits = reader.read_bits(9).unwrap(); assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0, 1, 0, 1, 1, 1])); + assert_eq!(reader.bits_read, 9); reader.seek_last_read().unwrap(); let bits = reader.read_bits(9).unwrap(); assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0, 1, 0, 1, 1, 1])); + assert_eq!(reader.bits_read, 9); } }