diff --git a/CHANGELOG.md b/CHANGELOG.md index 0dc280f1bc..02e376531a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ The minor version will be incremented upon a breaking change and the patch versi - lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)). - lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)). - lang: Add `discriminator` argument to `#[account]` attribute ([#3149](https://github.com/coral-xyz/anchor/pull/3149)). +- lang: Add `discriminator` argument to `#[event]` attribute ([#3152](https://github.com/coral-xyz/anchor/pull/3152)). ### Fixes diff --git a/lang/attribute/event/src/lib.rs b/lang/attribute/event/src/lib.rs index c637938769..77d6b99906 100644 --- a/lang/attribute/event/src/lib.rs +++ b/lang/attribute/event/src/lib.rs @@ -2,29 +2,52 @@ extern crate proc_macro; #[cfg(feature = "event-cpi")] use anchor_syn::parser::accounts::event_cpi::{add_event_cpi_accounts, EventAuthority}; -use quote::quote; -use syn::parse_macro_input; +use quote::{quote, ToTokens}; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, + token::Comma, + Expr, Ident, Lit, Token, +}; /// The event attribute allows a struct to be used with /// [emit!](./macro.emit.html) so that programs can log significant events in /// their programs that clients can subscribe to. Currently, this macro is for /// structs only. /// +/// # Args +/// +/// - `discriminator`: Override the default 8-byte discriminator +/// +/// **Usage:** `discriminator = ` +/// +/// All constant expressions are supported. +/// +/// **Examples:** +/// +/// - `discriminator = 0` (shortcut for `[0]`) +/// - `discriminator = [1, 2, 3, 4]` +/// - `discriminator = b"hi"` +/// - `discriminator = MY_DISC` +/// - `discriminator = get_disc(...)` +/// /// See the [`emit!` macro](emit!) for an example. #[proc_macro_attribute] pub fn event( - _args: proc_macro::TokenStream, + args: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { + let args = parse_macro_input!(args as EventArgs); let event_strct = parse_macro_input!(input as syn::ItemStruct); - let event_name = &event_strct.ident; - let discriminator: proc_macro2::TokenStream = { + let discriminator = args.discriminator.unwrap_or_else(|| { let discriminator_preimage = format!("event:{event_name}").into_bytes(); let discriminator = anchor_syn::hash::hash(&discriminator_preimage); - format!("{:?}", &discriminator.0[..8]).parse().unwrap() - }; + let discriminator: proc_macro2::TokenStream = + format!("{:?}", &discriminator.0[..8]).parse().unwrap(); + quote! { &#discriminator } + }); let ret = quote! { #[derive(anchor_lang::__private::EventIndex, AnchorSerialize, AnchorDeserialize)] @@ -33,14 +56,14 @@ pub fn event( impl anchor_lang::Event for #event_name { fn data(&self) -> Vec { let mut data = Vec::with_capacity(256); - data.extend_from_slice(&#discriminator); + data.extend_from_slice(#event_name::DISCRIMINATOR); self.serialize(&mut data).unwrap(); data } } impl anchor_lang::Discriminator for #event_name { - const DISCRIMINATOR: &'static [u8] = &#discriminator; + const DISCRIMINATOR: &'static [u8] = #discriminator; } }; @@ -57,6 +80,54 @@ pub fn event( proc_macro::TokenStream::from(ret) } +#[derive(Debug, Default)] +struct EventArgs { + /// Discriminator override + discriminator: Option, +} + +impl Parse for EventArgs { + fn parse(input: ParseStream) -> syn::Result { + // TODO: Share impl with `#[instruction]` + let mut parsed = Self::default(); + let args = input.parse_terminated::<_, Comma>(EventArg::parse)?; + for arg in args { + match arg.name.to_string().as_str() { + "discriminator" => { + let value = match &arg.value { + // Allow `discriminator = 42` + Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] }, + // Allow `discriminator = [0, 1, 2, 3]` + Expr::Array(arr) => quote! { &#arr }, + expr => expr.to_token_stream(), + }; + parsed.discriminator.replace(value); + } + _ => return Err(syn::Error::new(arg.name.span(), "Invalid argument")), + } + } + + Ok(parsed) + } +} + +struct EventArg { + name: Ident, + #[allow(dead_code)] + eq_token: Token![=], + value: Expr, +} + +impl Parse for EventArg { + fn parse(input: ParseStream) -> syn::Result { + Ok(Self { + name: input.parse()?, + eq_token: input.parse()?, + value: input.parse()?, + }) + } +} + // EventIndex is a marker macro. It functionally does nothing other than // allow one to mark fields with the `#[index]` inert attribute, which is // used to add metadata to IDLs. diff --git a/tests/custom-discriminator/programs/custom-discriminator/src/lib.rs b/tests/custom-discriminator/programs/custom-discriminator/src/lib.rs index cdcc479375..529ae3fe31 100644 --- a/tests/custom-discriminator/programs/custom-discriminator/src/lib.rs +++ b/tests/custom-discriminator/programs/custom-discriminator/src/lib.rs @@ -44,6 +44,11 @@ pub mod custom_discriminator { ctx.accounts.my_account.field = field; Ok(()) } + + pub fn event(_ctx: Context, field: u8) -> Result<()> { + emit!(MyEvent { field }); + Ok(()) + } } #[derive(Accounts)] @@ -70,3 +75,8 @@ pub struct CustomAccountIx<'info> { pub struct MyAccount { pub field: u8, } + +#[event(discriminator = 1)] +pub struct MyEvent { + field: u8, +} diff --git a/tests/custom-discriminator/tests/custom-discriminator.ts b/tests/custom-discriminator/tests/custom-discriminator.ts index f0d5fc1af8..b87e20594a 100644 --- a/tests/custom-discriminator/tests/custom-discriminator.ts +++ b/tests/custom-discriminator/tests/custom-discriminator.ts @@ -47,7 +47,26 @@ describe("custom-discriminator", () => { const myAccount = await program.account.myAccount.fetch( pubkeys.myAccount ); - assert.strictEqual(field, myAccount.field); + assert.strictEqual(myAccount.field, field); + }); + }); + + describe("Events", () => { + it("Works", async () => { + // Verify discriminator + const event = program.idl.events.find((acc) => acc.name === "myEvent")!; + assert(event.discriminator.length < 8); + + // Verify regular event works + await new Promise((res) => { + const field = 5; + const id = program.addEventListener("myEvent", (ev) => { + assert.strictEqual(ev.field, field); + program.removeEventListener(id); + res(); + }); + program.methods.event(field).rpc(); + }); }); }); });