Skip to content

Commit

Permalink
reader: Remove bits_read for end
Browse files Browse the repository at this point in the history
* Remove bits_read from end, as it's not really read, and we
  need to count the read when the leftover is read later.
* Fix missing bits_read increase on previous read
* Improve test coverage of reader by adding tests, and add more asserts
  • Loading branch information
wcampbell0x2a committed Dec 1, 2024
1 parent 7592eaf commit 1fb5e20
Showing 1 changed file with 53 additions and 15 deletions.
68 changes: 53 additions & 15 deletions src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> {
/// Return true if we are at the end of a reader and there are no cached bits in the reader.
/// Since this uses [Read] internally, this will return true when [Read] returns [ErrorKind::UnexpectedEof].
///
/// The byte that was read will be internally buffered
/// The byte that was read will be internally buffered and will *not* be included in the `bits_read` count.
#[inline]
pub fn end(&mut self) -> bool {
if self.leftover.is_some() {
Expand All @@ -160,13 +160,14 @@ impl<'a, R: Read + Seek> Reader<'a, R> {
#[cfg(feature = "logging")]
log::trace!("not end: read {:02x?}", &buf);

self.bits_read += 8;
self.leftover = Some(Leftover::Byte(buf[0]));
false
}
}

/// Used at the beginning of `from_reader`.
///
/// This will increment `bits_read`.
// TODO: maybe send into read_bytes() if amt >= 8
#[inline]
pub fn skip_bits(&mut self, amt: usize) -> Result<(), DekuError> {
Expand All @@ -186,7 +187,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> {
.map_err(|e| DekuError::Io(e.kind()))?;
}

// Unlike normal seek not couting as bits_read, this one does
// Unlike normal seek not counting as bits_read, this one does
// to keep from_bytes returns
self.bits_read += bytes_amt * 8;

Expand Down Expand Up @@ -374,6 +375,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> {
#[cfg(feature = "logging")]
log::trace!("read_bytes_const_leftover: returning {:02x?}", &buf);

self.bits_read += amt * 8;
return Ok(ReaderRet::Bytes);
}
let buf_len = buf.len();
Expand All @@ -389,7 +391,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> {
}
return Err(DekuError::Io(e.kind()));
}
self.bits_read += remaining * 8;
self.bits_read += amt * 8;

#[cfg(feature = "logging")]
log::trace!("read_bytes_leftover: returning {:02x?}", &buf);
Expand Down Expand Up @@ -462,6 +464,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> {
if remaining == 0 {
#[cfg(feature = "logging")]
log::trace!("read_bytes_const_leftover: returning {:02x?}", &buf);
self.bits_read += N * 8;

return Ok(ReaderRet::Bytes);
}
Expand All @@ -478,7 +481,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> {
}
return Err(DekuError::Io(e.kind()));
}
self.bits_read += remaining * 8;
self.bits_read += N * 8;

#[cfg(feature = "logging")]
log::trace!("read_bytes_const_leftover: returning {:02x?}", &buf);
Expand All @@ -503,15 +506,26 @@ mod tests {
let mut buf = [0; 2];
let _ = reader.read_bytes_const::<2>(&mut buf).unwrap();
assert!(reader.end());
assert_eq!(reader.bits_read, 8 * 2);

let input = hex!("aa");
let mut cursor = Cursor::new(input);
let mut reader = Reader::new(&mut cursor);
assert!(!reader.end());
let _ = reader.read_bits(4);
let _ = reader.read_bits(4).unwrap();
assert!(!reader.end());
let _ = reader.read_bits(4).unwrap();
assert!(reader.end());
assert_eq!(reader.bits_read, 8);

let input = hex!("aa");
let mut cursor = Cursor::new(input);
let mut reader = Reader::new(&mut cursor);
assert!(!reader.end());
let _ = reader.read_bits(4);
let mut buf = [0; 1];
let _ = reader.read_bytes(1, &mut buf).unwrap();
assert!(reader.end());
assert_eq!(reader.bits_read, 8);
}

#[test]
Expand All @@ -520,9 +534,22 @@ mod tests {
let input = hex!("aa");
let mut cursor = Cursor::new(input);
let mut reader = Reader::new(&mut cursor);
let _ = reader.read_bits(1);
let _ = reader.read_bits(4);
let _ = reader.read_bits(3);
let _ = reader.read_bits(1).unwrap();
let _ = reader.read_bits(4).unwrap();
let _ = reader.read_bits(3).unwrap();
assert_eq!(reader.bits_read, 8);
}

#[test]
#[cfg(feature = "bits")]
fn test_bits_skip() {
let input = hex!("aa");
let mut cursor = Cursor::new(input);
let mut reader = Reader::new(&mut cursor);
let _ = reader.skip_bits(1).unwrap();
let _ = reader.skip_bits(4).unwrap();
let _ = reader.skip_bits(3).unwrap();
assert_eq!(reader.bits_read, 8);
}

#[test]
Expand All @@ -531,8 +558,9 @@ mod tests {
let mut cursor = Cursor::new(input);
let mut reader = Reader::new(&mut cursor);
let mut buf = [0; 1];
let _ = reader.read_bytes(1, &mut buf);
let _ = reader.read_bytes(1, &mut buf).unwrap();
assert_eq!([0xaa], buf);
assert_eq!(reader.bits_read, 8);
}

#[test]
Expand All @@ -542,22 +570,28 @@ mod tests {
let mut cursor = Cursor::new(input);
let mut reader = Reader::new(&mut cursor);
let mut buf = [0; 1];
let _ = reader.read_bytes(1, &mut buf);
let _ = reader.read_bytes(1, &mut buf).unwrap();
assert_eq!([0xaa], buf);
assert_eq!(reader.bits_read, 8);

reader.seek_last_read().unwrap();
let _ = reader.read_bytes(1, &mut buf);
let _ = reader.read_bytes(1, &mut buf).unwrap();
assert_eq!([0xaa], buf);
assert_eq!(reader.bits_read, 8);

// 2 bytes (and const)
let input = hex!("aabb");
let mut cursor = Cursor::new(input);
let mut reader = Reader::new(&mut cursor);
let mut buf = [0; 2];
let _ = reader.read_bytes_const::<2>(&mut buf);
let _ = reader.read_bytes_const::<2>(&mut buf).unwrap();
assert_eq!([0xaa, 0xbb], buf);
assert_eq!(reader.bits_read, 16);

reader.seek_last_read().unwrap();
let _ = reader.read_bytes_const::<2>(&mut buf);
let _ = reader.read_bytes_const::<2>(&mut buf).unwrap();
assert_eq!([0xaa, 0xbb], buf);
assert_eq!(reader.bits_read, 16);
}

#[cfg(feature = "bits")]
Expand All @@ -568,18 +602,22 @@ mod tests {
let mut reader = Reader::new(&mut cursor);
let bits = reader.read_bits(4).unwrap();
assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0]));
assert_eq!(reader.bits_read, 4);
reader.seek_last_read().unwrap();
let bits = reader.read_bits(4).unwrap();
assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0]));
assert_eq!(reader.bits_read, 4);

// more than byte
let input = hex!("abd0");
let mut cursor = Cursor::new(input);
let mut reader = Reader::new(&mut cursor);
let bits = reader.read_bits(9).unwrap();
assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0, 1, 0, 1, 1, 1]));
assert_eq!(reader.bits_read, 9);
reader.seek_last_read().unwrap();
let bits = reader.read_bits(9).unwrap();
assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0, 1, 0, 1, 1, 1]));
assert_eq!(reader.bits_read, 9);
}
}

0 comments on commit 1fb5e20

Please sign in to comment.