Skip to content

Commit

Permalink
lang: Add #[instruction] attribute proc-macro (#3137)
Browse files Browse the repository at this point in the history
  • Loading branch information
acheroncrypto authored Jul 30, 2024
1 parent 5a20cd9 commit 3f945f6
Show file tree
Hide file tree
Showing 16 changed files with 292 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/reusable-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ jobs:
path: tests/safety-checks
- cmd: cd tests/custom-coder && anchor test --skip-lint && npx tsc --noEmit
path: tests/custom-coder
- cmd: cd tests/custom-discriminator && anchor test && npx tsc --noEmit
path: tests/custom-discriminator
- cmd: cd tests/validator-clone && anchor test --skip-lint && npx tsc --noEmit
path: tests/validator-clone
- cmd: cd tests/cpi-returns && anchor test --skip-lint && npx tsc --noEmit
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- ts: Add optional `wallet` property to the `Provider` interface ([#3130](https://github.com/coral-xyz/anchor/pull/3130)).
- cli: Warn if `anchor-spl/idl-build` is missing ([#3133](https://github.com/coral-xyz/anchor/pull/3133)).
- client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)).
- lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)).

### Fixes

Expand Down
50 changes: 50 additions & 0 deletions lang/attribute/program/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,53 @@ pub fn interface(
// discriminator.
input
}

/// This attribute is used to override the Anchor defaults of program instructions.
///
/// # Args
///
/// - `discriminator`: Override the default 8-byte discriminator
///
/// **Usage:** `discriminator = <CONST_EXPR>`
///
/// All constant expressions are supported.
///
/// **Examples:**
///
/// - `discriminator = 0` (shortcut for `[0]`)
/// - `discriminator = [1, 2, 3, 4]`
/// - `discriminator = b"hi"`
/// - `discriminator = MY_DISC`
/// - `discriminator = get_disc(...)`
///
/// # Example
///
/// ```ignore
/// use anchor_lang::prelude::*;
///
/// declare_id!("CustomDiscriminator111111111111111111111111");
///
/// #[program]
/// pub mod custom_discriminator {
/// use super::*;
///
/// #[instruction(discriminator = [1, 2, 3, 4])]
/// pub fn my_ix(_ctx: Context<MyIx>) -> Result<()> {
/// Ok(())
/// }
/// }
///
/// #[derive(Accounts)]
/// pub struct MyIx<'info> {
/// pub signer: Signer<'info>,
/// }
/// ```
#[proc_macro_attribute]
pub fn instruction(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
// This macro itself is a no-op, but the `#[program]` macro will detect this attribute and use
// the arguments to transform the instruction.
input
}
6 changes: 3 additions & 3 deletions lang/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub use anchor_attribute_account::{account, declare_id, pubkey, zero_copy};
pub use anchor_attribute_constant::constant;
pub use anchor_attribute_error::*;
pub use anchor_attribute_event::{emit, event};
pub use anchor_attribute_program::{declare_program, program};
pub use anchor_attribute_program::{declare_program, instruction, program};
pub use anchor_derive_accounts::Accounts;
pub use anchor_derive_serde::{AnchorDeserialize, AnchorSerialize};
pub use anchor_derive_space::InitSpace;
Expand Down Expand Up @@ -392,8 +392,8 @@ pub mod prelude {
accounts::signer::Signer, accounts::system_account::SystemAccount,
accounts::sysvar::Sysvar, accounts::unchecked_account::UncheckedAccount, constant,
context::Context, context::CpiContext, declare_id, declare_program, emit, err, error,
event, program, pubkey, require, require_eq, require_gt, require_gte, require_keys_eq,
require_keys_neq, require_neq,
event, instruction, program, pubkey, require, require_eq, require_gt, require_gte,
require_keys_eq, require_keys_neq, require_neq,
solana_program::bpf_loader_upgradeable::UpgradeableLoaderState, source,
system_program::System, zero_copy, AccountDeserialize, AccountSerialize, Accounts,
AccountsClose, AccountsExit, AnchorDeserialize, AnchorSerialize, Discriminator, Id,
Expand Down
28 changes: 19 additions & 9 deletions lang/syn/src/codegen/program/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,25 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
.unwrap()
})
.collect();
let ix_data_trait = {
let discriminator = ix
.interface_discriminator
.unwrap_or_else(|| sighash(SIGHASH_GLOBAL_NAMESPACE, name));
let discriminator: proc_macro2::TokenStream =
format!("{discriminator:?}").parse().unwrap();
let impls = {
let discriminator = match ix.ix_attr.as_ref() {
Some(ix_attr) if ix_attr.discriminator.is_some() => {
ix_attr.discriminator.as_ref().unwrap().to_owned()
}
_ => {
// TODO: Remove `interface_discriminator`
let discriminator = ix
.interface_discriminator
.unwrap_or_else(|| sighash(SIGHASH_GLOBAL_NAMESPACE, name));
let discriminator: proc_macro2::TokenStream =
format!("{discriminator:?}").parse().unwrap();
quote! { &#discriminator }
}
};

quote! {
impl anchor_lang::Discriminator for #ix_name_camel {
const DISCRIMINATOR: &'static [u8] = &#discriminator;
const DISCRIMINATOR: &'static [u8] = #discriminator;
}
impl anchor_lang::InstructionData for #ix_name_camel {}
impl anchor_lang::Owner for #ix_name_camel {
Expand All @@ -46,7 +56,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
#[derive(AnchorSerialize, AnchorDeserialize)]
pub struct #ix_name_camel;

#ix_data_trait
#impls
}
} else {
quote! {
Expand All @@ -56,7 +66,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
#(#raw_args),*
}

#ix_data_trait
#impls
}
}
})
Expand Down
52 changes: 52 additions & 0 deletions lang/syn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use syn::parse::{Error as ParseError, Parse, ParseStream, Result as ParseResult}
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::Lit;
use syn::{
Expr, Generics, Ident, ItemEnum, ItemFn, ItemMod, ItemStruct, LitInt, PatType, Token, Type,
TypePath,
Expand Down Expand Up @@ -68,7 +69,58 @@ pub struct Ix {
// The ident for the struct deriving Accounts.
pub anchor_ident: Ident,
// The discriminator based on the `#[interface]` attribute.
// TODO: Remove and use `ix_attr`
pub interface_discriminator: Option<[u8; 8]>,
/// `#[instruction]` attribute
pub ix_attr: Option<IxAttr>,
}

/// `#[instruction]` attribute proc-macro
#[derive(Debug, Default)]
pub struct IxAttr {
/// Discriminator override
pub discriminator: Option<TokenStream>,
}

impl Parse for IxAttr {
fn parse(input: ParseStream) -> ParseResult<Self> {
let mut attr = Self::default();
let args = input.parse_terminated::<_, Comma>(AttrArg::parse)?;
for arg in args {
match arg.name.to_string().as_str() {
"discriminator" => {
let value = match &arg.value {
// Allow `discriminator = 42`
Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
// Allow `discriminator = [0, 1, 2, 3]`
Expr::Array(arr) => quote! { &#arr },
expr => expr.to_token_stream(),
};
attr.discriminator.replace(value)
}
_ => return Err(ParseError::new(arg.name.span(), "Invalid argument")),
};
}

Ok(attr)
}
}

struct AttrArg {
name: Ident,
#[allow(dead_code)]
eq_token: Token!(=),
value: Expr,
}

impl Parse for AttrArg {
fn parse(input: ParseStream) -> ParseResult<Self> {
Ok(Self {
name: input.parse()?,
eq_token: input.parse()?,
value: input.parse()?,
})
}
}

#[derive(Debug)]
Expand Down
16 changes: 15 additions & 1 deletion lang/syn/src/parser/program/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::parser::docs;
use crate::parser::program::ctx_accounts_ident;
use crate::parser::spl_interface;
use crate::{FallbackFn, Ix, IxArg, IxReturn};
use crate::{FallbackFn, Ix, IxArg, IxAttr, IxReturn};
use syn::parse::{Error as ParseError, Result as ParseResult};
use syn::spanned::Spanned;

Expand All @@ -25,6 +25,7 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
})
.map(|method: &syn::ItemFn| {
let (ctx, args) = parse_args(method)?;
let ix_attr = parse_ix_attr(&method.attrs)?;
let interface_discriminator = spl_interface::parse(&method.attrs);
let docs = docs::parse(&method.attrs);
let returns = parse_return(method)?;
Expand All @@ -37,6 +38,7 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
anchor_ident,
returns,
interface_discriminator,
ix_attr,
})
})
.collect::<ParseResult<Vec<Ix>>>()?;
Expand Down Expand Up @@ -71,6 +73,18 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
Ok((ixs, fallback_fn))
}

/// Parse `#[instruction]` attribute proc-macro.
fn parse_ix_attr(attrs: &[syn::Attribute]) -> ParseResult<Option<IxAttr>> {
attrs
.iter()
.find(|attr| match attr.path.segments.last() {
Some(seg) => seg.ident == "instruction",
_ => false,
})
.map(|attr| attr.parse_args())
.transpose()
}

pub fn parse_args(method: &syn::ItemFn) -> ParseResult<(IxArg, Vec<IxArg>)> {
let mut args: Vec<IxArg> = method
.sig
Expand Down
9 changes: 9 additions & 0 deletions tests/custom-discriminator/Anchor.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[programs.localnet]
custom_discriminator = "CustomDiscriminator111111111111111111111111"

[provider]
cluster = "localnet"
wallet = "~/.config/solana/id.json"

[scripts]
test = "yarn run ts-mocha -p ./tsconfig.json -t 1000000 tests/**/*.ts"
14 changes: 14 additions & 0 deletions tests/custom-discriminator/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[workspace]
members = [
"programs/*"
]
resolver = "2"

[profile.release]
overflow-checks = true
lto = "fat"
codegen-units = 1
[profile.release.build-override]
opt-level = 3
incremental = false
codegen-units = 1
16 changes: 16 additions & 0 deletions tests/custom-discriminator/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"name": "custom-discriminator",
"version": "0.30.1",
"license": "(MIT OR Apache-2.0)",
"homepage": "https://github.com/coral-xyz/anchor#readme",
"bugs": {
"url": "https://github.com/coral-xyz/anchor/issues"
},
"repository": {
"type": "git",
"url": "https://github.com/coral-xyz/anchor.git"
},
"engines": {
"node": ">=17"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "custom-discriminator"
version = "0.1.0"
description = "Created with Anchor"
edition = "2021"

[lib]
crate-type = ["cdylib", "lib"]
name = "custom_discriminator"

[features]
no-entrypoint = []
no-idl = []
cpi = ["no-entrypoint"]
default = []
idl-build = ["anchor-lang/idl-build"]

[dependencies]
anchor-lang = { path = "../../../../lang" }
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[target.bpfel-unknown-unknown.dependencies.std]
features = []
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use anchor_lang::prelude::*;

declare_id!("CustomDiscriminator111111111111111111111111");

const CONST_DISC: &'static [u8] = &[55, 66, 77, 88];

const fn get_disc(input: &str) -> &'static [u8] {
match input.as_bytes() {
b"wow" => &[4 + 5, 55 / 5],
_ => unimplemented!(),
}
}

#[program]
pub mod custom_discriminator {
use super::*;

#[instruction(discriminator = 0)]
pub fn int(_ctx: Context<DefaultIx>) -> Result<()> {
Ok(())
}

#[instruction(discriminator = [1, 2, 3, 4])]
pub fn array(_ctx: Context<DefaultIx>) -> Result<()> {
Ok(())
}

#[instruction(discriminator = b"hi")]
pub fn byte_str(_ctx: Context<DefaultIx>) -> Result<()> {
Ok(())
}

#[instruction(discriminator = CONST_DISC)]
pub fn constant(_ctx: Context<DefaultIx>) -> Result<()> {
Ok(())
}

#[instruction(discriminator = get_disc("wow"))]
pub fn const_fn(_ctx: Context<DefaultIx>) -> Result<()> {
Ok(())
}
}

#[derive(Accounts)]
pub struct DefaultIx<'info> {
pub signer: Signer<'info>,
}
31 changes: 31 additions & 0 deletions tests/custom-discriminator/tests/custom-discriminator.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import * as anchor from "@coral-xyz/anchor";
import assert from "assert";

import type { CustomDiscriminator } from "../target/types/custom_discriminator";

describe("custom-discriminator", () => {
anchor.setProvider(anchor.AnchorProvider.env());
const program: anchor.Program<CustomDiscriminator> =
anchor.workspace.customDiscriminator;

describe("Can use custom instruction discriminators", () => {
const testCommon = async (ixName: keyof typeof program["methods"]) => {
const tx = await program.methods[ixName]().transaction();

// Verify discriminator
const ix = program.idl.instructions.find((ix) => ix.name === ixName)!;
assert(ix.discriminator.length < 8);
const data = tx.instructions[0].data;
assert(data.equals(Buffer.from(ix.discriminator)));

// Verify tx runs
await program.provider.sendAndConfirm!(tx);
};

it("Integer", () => testCommon("int"));
it("Array", () => testCommon("array"));
it("Byte string", () => testCommon("byteStr"));
it("Constant", () => testCommon("constant"));
it("Const Fn", () => testCommon("constFn"));
});
});
Loading

0 comments on commit 3f945f6

Please sign in to comment.