diff --git a/deku-derive/src/lib.rs b/deku-derive/src/lib.rs index a2a0de3f..af54f686 100644 --- a/deku-derive/src/lib.rs +++ b/deku-derive/src/lib.rs @@ -125,6 +125,8 @@ struct DekuData { /// default context passed to the field ctx_default: Option>, + update_ctx: Option>, + /// A magic value that must appear at the start of this struct/enum's data magic: Option, @@ -182,6 +184,7 @@ impl DekuData { endian: receiver.endian, ctx: receiver.ctx, ctx_default: receiver.ctx_default, + update_ctx: receiver.update_ctx, magic: receiver.magic, id: receiver.id, id_type: receiver.id_type?, @@ -402,6 +405,13 @@ struct FieldData { /// map field when updating struct update: Option, + /// map field when updating struct + update_custom: Option, + + update_ctx: Option>, + + call_update: bool, + /// custom field reader code reader: Option, @@ -450,6 +460,12 @@ impl FieldData { .transpose() .map_err(|e| e.to_compile_error())?; + let update_ctx = receiver + .update_ctx? + .map(|s| s.parse_with(Punctuated::parse_terminated)) + .transpose() + .map_err(|e| e.to_compile_error())?; + let data = Self { ident: receiver.ident, ty: receiver.ty, @@ -464,6 +480,9 @@ impl FieldData { map: receiver.map?, ctx, update: receiver.update?, + update_custom: receiver.update_custom?, + update_ctx, + call_update: receiver.call_update, reader: receiver.reader?, writer: receiver.writer?, skip: receiver.skip, @@ -645,6 +664,9 @@ struct DekuReceiver { #[darling(default)] ctx_default: Option>, + #[darling(default)] + update_ctx: Option>, + /// A magic value that must appear at the start of this struct/enum's data #[darling(default)] magic: Option, @@ -797,6 +819,19 @@ struct DekuFieldReceiver { #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")] update: Result, ReplacementError>, + /// map field when updating struct + #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")] + update_custom: Result, ReplacementError>, + + /// skip field reading/writing + #[darling(default)] + call_update: bool, + + // TODO: The type of it should be `Punctuated` + // https://github.com/TedDriggs/darling/pull/98 + #[darling(default = "default_res_opt", map = "map_option_litstr")] + update_ctx: Result, ReplacementError>, + /// custom field reader code #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")] reader: Result, ReplacementError>, diff --git a/deku-derive/src/macros/deku_write.rs b/deku-derive/src/macros/deku_write.rs index 6e0b7154..9dc56515 100644 --- a/deku-derive/src/macros/deku_write.rs +++ b/deku-derive/src/macros/deku_write.rs @@ -97,6 +97,7 @@ fn emit_struct(input: &DekuData) -> Result { } let (ctx_types, ctx_arg) = gen_ctx_types_and_arg(input.ctx.as_ref())?; + let (update_ctx_types, update_ctx_arg) = gen_ctx_types_and_arg(input.update_ctx.as_ref())?; let write_body = quote! { match *self { @@ -113,8 +114,8 @@ fn emit_struct(input: &DekuData) -> Result { let update_use = check_update_use(&field_updates); tokens.extend(quote! { - impl #imp DekuUpdate for #ident #wher { - fn update(&mut self) -> core::result::Result<(), ::#crate_::DekuError> { + impl #imp DekuUpdate<#update_ctx_types> for #ident #wher { + fn update(&mut self, #update_ctx_arg) -> core::result::Result<(), ::#crate_::DekuError> { #update_use #(#field_updates)* @@ -315,6 +316,7 @@ fn emit_enum(input: &DekuData) -> Result { } let (ctx_types, ctx_arg) = gen_ctx_types_and_arg(input.ctx.as_ref())?; + let (update_ctx_types, update_ctx_arg) = gen_ctx_types_and_arg(input.update_ctx.as_ref())?; let write_body = quote! { #magic_write @@ -330,8 +332,8 @@ fn emit_enum(input: &DekuData) -> Result { let update_use = check_update_use(&variant_updates); tokens.extend(quote! { - impl #imp DekuUpdate for #ident #wher { - fn update(&mut self) -> core::result::Result<(), ::#crate_::DekuError> { + impl #imp DekuUpdate<#update_ctx_types> for #ident #wher { + fn update(&mut self, #update_ctx_arg) -> core::result::Result<(), ::#crate_::DekuError> { #update_use match self { @@ -410,17 +412,34 @@ fn emit_field_update( return None; } let field_ident = f.get_ident(i, object_prefix.is_none()); - let deref = if object_prefix.is_none() { - Some(quote! { * }) - } else { - None + let deref = match object_prefix.is_none() { + true => Some(quote! { * }), + false => None, }; - f.update.as_ref().map(|field_update| { - quote! { - #deref #object_prefix #field_ident = (#field_update).try_into()?; + if f.call_update { + let ctx = match f.update_ctx.as_ref() { + Some(ctx) => quote! {(#ctx)}, + None => quote! {()}, + }; + let a = quote! { + #deref #object_prefix #field_ident.update(#ctx)?; + }; + + Some(a) + } else { + if let Some(custom) = &f.update_custom { + Some(quote! { + #custom(&mut #object_prefix #field_ident)?; + }) + } else { + f.update.as_ref().map(|field_update| { + quote! { + #deref #object_prefix #field_ident = (#field_update).try_into()?; + } + }) } - }) + } } fn emit_bit_byte_offsets( diff --git a/examples/update.rs b/examples/update.rs new file mode 100644 index 00000000..ebd52e4e --- /dev/null +++ b/examples/update.rs @@ -0,0 +1,68 @@ +use deku::prelude::*; + +#[derive(Debug, DekuRead, DekuWrite, PartialEq)] +pub struct Test { + #[deku(call_update, update_ctx = "self.val.len() as u16, 0")] + hdr: Hdr, + + #[deku(count = "hdr.length")] + val: Vec, + + #[deku(call_update)] + no_update_ctx: NoUpdateCtx, + + #[deku(update_custom = "Self::custom")] + num: u8, + + #[deku(update_custom = "Self::other_custom")] + other_num: (u8, u32), +} + +impl Test { + fn custom(num: &mut u8) -> Result<(), DekuError> { + *num = 1; + + Ok(()) + } + + fn other_custom(num: &mut (u8, u32)) -> Result<(), DekuError> { + *num = (0xf0, 0x0f); + + Ok(()) + } +} + +#[derive(Debug, DekuRead, DekuWrite, PartialEq)] +#[deku(update_ctx = "val_len: u16, _na: u8")] +struct Hdr { + #[deku(update = "val_len")] + length: u8, +} + +#[derive(Debug, DekuRead, DekuWrite, PartialEq)] +struct NoUpdateCtx { + #[deku(update = "0xff")] + val: u8, +} + +fn main() { + let mut test = Test { + hdr: Hdr { length: 2 }, + val: vec![1, 2], + no_update_ctx: NoUpdateCtx { val: 0 }, + num: 0, + other_num: (0, 0), + }; + + test.val = vec![1, 2, 3]; + test.update(()).unwrap(); + + let expected = Test { + hdr: Hdr { length: 3 }, + val: test.val.clone(), + no_update_ctx: NoUpdateCtx { val: 0xff }, + num: 1, + other_num: (0xf0, 0x0f), + }; + assert_eq!(expected, test); +} diff --git a/src/attributes.rs b/src/attributes.rs index ee596a00..625fd3cd 100644 --- a/src/attributes.rs +++ b/src/attributes.rs @@ -483,7 +483,7 @@ assert_eq!( value.items.push(0xFF); // update it, this will update the `count` field -value.update().unwrap(); +value.update(()).unwrap(); assert_eq!( DekuTest { count: 0x03, items: vec![0xAB, 0xCD, 0xFF] }, diff --git a/src/lib.rs b/src/lib.rs index 64d36de5..b0e95448 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,7 +138,7 @@ let data_out = val.to_bytes().unwrap(); assert_eq!(vec![0x02, 0xBE, 0xEF, 0xAA], data_out); // Use `update` to update `count` -val.update().unwrap(); +val.update(()).unwrap(); assert_eq!(DekuTest { count: 0x03, @@ -446,9 +446,9 @@ pub trait DekuContainerWrite: DekuWrite<()> { } /// "Updater" trait: apply mutations to a type -pub trait DekuUpdate { +pub trait DekuUpdate { /// Apply updates - fn update(&mut self) -> Result<(), DekuError>; + fn update(&mut self, ctx: Ctx) -> Result<(), DekuError>; } /// "Extended Enum" trait: obtain additional enum information diff --git a/tests/test_attributes/test_update.rs b/tests/test_attributes/test_update.rs index c8feec49..7a207867 100644 --- a/tests/test_attributes/test_update.rs +++ b/tests/test_attributes/test_update.rs @@ -18,7 +18,7 @@ fn test_update() { assert_eq!(TestStruct { field_a: 0x01 }, ret_read); // `field_a` field should now be increased - ret_read.update().unwrap(); + ret_read.update(()).unwrap(); assert_eq!(0x05, ret_read.field_a); let ret_write: Vec = ret_read.try_into().unwrap(); @@ -53,7 +53,7 @@ fn test_update_from_field() { ret_read.data.push(0xff); // `count` field should now be increased - ret_read.update().unwrap(); + ret_read.update(()).unwrap(); assert_eq!(3, ret_read.count); // Write @@ -75,5 +75,5 @@ fn test_update_error() { let mut val = TestStruct { count: 0x01 }; - val.update().unwrap(); + val.update(()).unwrap(); }