From 3585ac5937d03ad14abf4d142f1cb3d7590cf471 Mon Sep 17 00:00:00 2001 From: Victor Martinez <49537445+JasterV@users.noreply.github.com> Date: Mon, 26 Feb 2024 15:56:59 +0100 Subject: [PATCH] refactor: clean up the macro implementation --- test-context-macros/Cargo.toml | 1 + test-context-macros/src/lib.rs | 79 +++++++++++++++++++--------------- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/test-context-macros/Cargo.toml b/test-context-macros/Cargo.toml index 6eb7978..6f25c35 100644 --- a/test-context-macros/Cargo.toml +++ b/test-context-macros/Cargo.toml @@ -16,3 +16,4 @@ proc-macro = true [dependencies] quote = "1.0.3" syn = { version = "^2", features = ["full"] } +proc-macro2 = "1.0.78" diff --git a/test-context-macros/src/lib.rs b/test-context-macros/src/lib.rs index 36492df..051fa91 100644 --- a/test-context-macros/src/lib.rs +++ b/test-context-macros/src/lib.rs @@ -1,5 +1,6 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; +use syn::Ident; /// Macro to use on tests to add the setup/teardown functionality of your context. /// @@ -36,42 +37,9 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { let wrapped_name = format_ident!("__test_context_wrapped_{}", name); let outer_body = if is_async { - quote! { - { - use test_context::futures::FutureExt; - let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await; - let wrapped_ctx = &mut ctx; - let result = async move { - std::panic::AssertUnwindSafe( - #wrapped_name(wrapped_ctx) - ).catch_unwind().await - }.await; - <#context_type as test_context::AsyncTestContext>::teardown(ctx).await; - match result { - Ok(returned_value) => returned_value, - Err(err) => { - std::panic::resume_unwind(err); - } - } - } - } + async_implementation(context_type, &wrapped_name) } else { - quote! { - { - let mut ctx = <#context_type as test_context::TestContext>::setup(); - let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx); - let result = std::panic::catch_unwind(move || { - #wrapped_name(*wrapper) - }); - <#context_type as test_context::TestContext>::teardown(ctx); - match result { - Ok(returned_value) => returned_value, - Err(err) => { - std::panic::resume_unwind(err); - } - } - } - } + sync_implementation(context_type, &wrapped_name) }; let async_tag = if is_async { @@ -88,3 +56,44 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream { }; result.into() } + +fn async_implementation(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream { + quote! { + { + use test_context::futures::FutureExt; + let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await; + let wrapped_ctx = &mut ctx; + let result = async move { + std::panic::AssertUnwindSafe( + #wrapped_name(wrapped_ctx) + ).catch_unwind().await + }.await; + <#context_type as test_context::AsyncTestContext>::teardown(ctx).await; + match result { + Ok(returned_value) => returned_value, + Err(err) => { + std::panic::resume_unwind(err); + } + } + } + } +} + +fn sync_implementation(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream { + quote! { + { + let mut ctx = <#context_type as test_context::TestContext>::setup(); + let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx); + let result = std::panic::catch_unwind(move || { + #wrapped_name(*wrapper) + }); + <#context_type as test_context::TestContext>::teardown(ctx); + match result { + Ok(returned_value) => returned_value, + Err(err) => { + std::panic::resume_unwind(err); + } + } + } + } +}