From 0fed4463417e90f958504512fc3dbcc60085223f Mon Sep 17 00:00:00 2001 From: Victor Martinez <49537445+JasterV@users.noreply.github.com> Date: Mon, 26 Feb 2024 15:56:59 +0100 Subject: [PATCH] refactor: clean up the macro implementation --- test-context-macros/Cargo.toml | 1 + test-context-macros/src/lib.rs | 94 ++++++++++++++++++---------------- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/test-context-macros/Cargo.toml b/test-context-macros/Cargo.toml index 6eb7978..6f25c35 100644 --- a/test-context-macros/Cargo.toml +++ b/test-context-macros/Cargo.toml @@ -16,3 +16,4 @@ proc-macro = true [dependencies] quote = "1.0.3" syn = { version = "^2", features = ["full"] } +proc-macro2 = "1.0.78" diff --git a/test-context-macros/src/lib.rs b/test-context-macros/src/lib.rs index 36492df..28ca235 100644 --- a/test-context-macros/src/lib.rs +++ b/test-context-macros/src/lib.rs @@ -1,5 +1,6 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; +use syn::Ident; /// Macro to use on tests to add the setup/teardown functionality of your context. /// @@ -25,53 +26,19 @@ use quote::{format_ident, quote}; pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { let context_type = syn::parse_macro_input!(attr as syn::Ident); let input = syn::parse_macro_input!(item as syn::ItemFn); - let ret = &input.sig.output; let name = &input.sig.ident; let arguments = &input.sig.inputs; - let inner_body = &input.block; + let body = &input.block; let attrs = &input.attrs; let is_async = input.sig.asyncness.is_some(); let wrapped_name = format_ident!("__test_context_wrapped_{}", name); - let outer_body = if is_async { - quote! { - { - use test_context::futures::FutureExt; - let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await; - let wrapped_ctx = &mut ctx; - let result = async move { - std::panic::AssertUnwindSafe( - #wrapped_name(wrapped_ctx) - ).catch_unwind().await - }.await; - <#context_type as test_context::AsyncTestContext>::teardown(ctx).await; - match result { - Ok(returned_value) => returned_value, - Err(err) => { - std::panic::resume_unwind(err); - } - } - } - } + let wrapper_body = if is_async { + async_wrapper_body(context_type, &wrapped_name) } else { - quote! { - { - let mut ctx = <#context_type as test_context::TestContext>::setup(); - let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx); - let result = std::panic::catch_unwind(move || { - #wrapped_name(*wrapper) - }); - <#context_type as test_context::TestContext>::teardown(ctx); - match result { - Ok(returned_value) => returned_value, - Err(err) => { - std::panic::resume_unwind(err); - } - } - } - } + sync_wrapper_body(context_type, &wrapped_name) }; let async_tag = if is_async { @@ -80,11 +47,52 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { quote! {} }; - let result = quote! { + quote! { #(#attrs)* - #async_tag fn #name() #ret #outer_body + #async_tag fn #name() #ret #wrapper_body - #async_tag fn #wrapped_name(#arguments) #ret #inner_body - }; - result.into() + #async_tag fn #wrapped_name(#arguments) #ret #body + } + .into() +} + +fn async_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream { + quote! { + { + use test_context::futures::FutureExt; + let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await; + let wrapped_ctx = &mut ctx; + let result = async move { + std::panic::AssertUnwindSafe( + #wrapped_name(wrapped_ctx) + ).catch_unwind().await + }.await; + <#context_type as test_context::AsyncTestContext>::teardown(ctx).await; + match result { + Ok(returned_value) => returned_value, + Err(err) => { + std::panic::resume_unwind(err); + } + } + } + } +} + +fn sync_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream { + quote! { + { + let mut ctx = <#context_type as test_context::TestContext>::setup(); + let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx); + let result = std::panic::catch_unwind(move || { + #wrapped_name(*wrapper) + }); + <#context_type as test_context::TestContext>::teardown(ctx); + match result { + Ok(returned_value) => returned_value, + Err(err) => { + std::panic::resume_unwind(err); + } + } + } + } }