Skip to content

Commit

Permalink
Seek first bytes lengths before seek bits for __deku_input (#502)
Browse files Browse the repository at this point in the history
* Seek first bytes lengths before seek bits for __deku_input

* For bits input into from_bytes, first seek the byte amount until we can
 seek the bits amount

Closes #500

* deku-derive: Add expect to seeking
  • Loading branch information
wcampbell0x2a authored Nov 5, 2024
1 parent 35b1d44 commit 2b24a56
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 7 deletions.
9 changes: 6 additions & 3 deletions deku-derive/src/macros/deku_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {
{
use ::#crate_::no_std_io::Seek;
use ::#crate_::no_std_io::SeekFrom;
if let Err(e) = __deku_reader.seek(SeekFrom::Current(i64::try_from(#num).unwrap())) {
let seek_amt = i64::try_from(#num).expect("could not convert into i64");
if let Err(e) = __deku_reader.seek(SeekFrom::Current(seek_amt)) {
return Err(::#crate_::DekuError::Io(e.kind()));
}
}
Expand All @@ -52,7 +53,8 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {
{
use ::#crate_::no_std_io::Seek;
use ::#crate_::no_std_io::SeekFrom;
if let Err(e) = __deku_reader.seek(SeekFrom::End(i64::try_from(#num).unwrap())) {
let seek_amt = i64::try_from(#num).expect("could not convert into i64");
if let Err(e) = __deku_reader.seek(SeekFrom::End(seek_amt)) {
return Err(::#crate_::DekuError::Io(e.kind()));
}
}
Expand All @@ -62,7 +64,8 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {
{
use ::#crate_::no_std_io::Seek;
use ::#crate_::no_std_io::SeekFrom;
if let Err(e) = __deku_reader.seek(SeekFrom::Start(u64::try_from(#num).unwrap())) {
let seek_amt = u64::try_from(#num).expect("could not convert into u64");
if let Err(e) = __deku_reader.seek(SeekFrom::Start(seek_amt)) {
return Err(::#crate_::DekuError::Io(e.kind()));
}
}
Expand Down
9 changes: 6 additions & 3 deletions deku-derive/src/macros/deku_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {
{
use ::#crate_::no_std_io::Seek;
use ::#crate_::no_std_io::SeekFrom;
if let Err(e) = __deku_writer.seek(SeekFrom::Current(i64::try_from(#num).unwrap())) {
let seek_amt = i64::try_from(#num).expect("could not convert into i64");
if let Err(e) = __deku_writer.seek(SeekFrom::Current(seek_amt)) {
return Err(::#crate_::DekuError::Io(e.kind()));
}
}
Expand All @@ -43,7 +44,8 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {
{
use ::#crate_::no_std_io::Seek;
use ::#crate_::no_std_io::SeekFrom;
if let Err(e) = __deku_writer.seek(SeekFrom::End(i64::try_from(#num).unwrap())) {
let seek_amt = i64::try_from(#num).expect("could not convert into i64");
if let Err(e) = __deku_writer.seek(SeekFrom::End(seek_amt)) {
return Err(::#crate_::DekuError::Io(e.kind()));
}
}
Expand All @@ -53,7 +55,8 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {
{
use ::#crate_::no_std_io::Seek;
use ::#crate_::no_std_io::SeekFrom;
if let Err(e) = __deku_writer.seek(SeekFrom::Start(u64::try_from(#num).unwrap())) {
let seek_amt = u64::try_from(#num).expect("could not convert into u64");
if let Err(e) = __deku_writer.seek(SeekFrom::Start(seek_amt)) {
return Err(::#crate_::DekuError::Io(e.kind()));
}
}
Expand Down
18 changes: 17 additions & 1 deletion src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,24 @@ impl<'a, R: Read + Seek> Reader<'a, R> {
{
#[cfg(feature = "logging")]
log::trace!("skip_bits: {amt}");

let bytes_amt = amt / 8;
let bits_amt = amt % 8;

// first, seek with bytes
if bytes_amt != 0 {
self.seek(SeekFrom::Current(
i64::try_from(bytes_amt).expect("could not convert seek usize into i64"),
))
.map_err(|e| DekuError::Io(e.kind()))?;
}

// Unlike normal seek not couting as bits_read, this one does
// to keep from_bytes returns
self.bits_read += bytes_amt * 8;

// Save, and keep the leftover bits since the read will most likely be less than a byte
self.read_bits(amt)?;
self.read_bits(bits_amt)?;
}

#[cfg(not(feature = "bits"))]
Expand Down
27 changes: 27 additions & 0 deletions tests/test_from_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,30 @@ fn test_from_bytes_enum() {
assert_eq!(6, i);
assert_eq!(0b0101_1010u8, rest[0]);
}

#[cfg(feature = "bits")]
#[test]
fn test_from_bytes_long() {
#[derive(Debug, PartialEq, DekuRead, DekuWrite)]
#[deku(id_type = "u8", bits = 4)]
enum TestDeku {
#[deku(id = "0b0110")]
VariantA(#[deku(bits = 4)] u8),
#[deku(id = "0b0101")]
VariantB(#[deku(bits = 2)] u8),
}

let mut test_data = vec![0x00; 200];
test_data.extend([0b0110_0110u8, 0b0101_1010u8].to_vec());

let ((rest, i), ret_read) = TestDeku::from_bytes((&test_data, 200 * 8)).unwrap();
assert_eq!(TestDeku::VariantA(0b0110), ret_read);
assert_eq!(1, rest.len());
assert_eq!(0, i);

let ((rest, i), ret_read) = TestDeku::from_bytes((rest, i)).unwrap();
assert_eq!(TestDeku::VariantB(0b10), ret_read);
assert_eq!(1, rest.len());
assert_eq!(6, i);
assert_eq!(0b0101_1010u8, rest[0]);
}

0 comments on commit 2b24a56

Please sign in to comment.