diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 2116d5ea0..d2b74592b 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -65,6 +65,9 @@ struct DiffRequest { /// A flag to enable TBR analysis during reverse-mode differentiation. bool EnableTBRAnalysis = false; bool EnableVariedAnalysis = false; + /// A flag specifying whether this differentiation is to be used + /// in immediate contexts. + bool ImmediateMode = false; /// Puts the derived function and its code in the diff call void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD, clang::Sema& SemaRef); diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index f926488e2..8d4ba7574 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -747,6 +747,8 @@ namespace clad { request.RequestedDerivativeOrder = derivative_order; if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) request.use_enzyme = true; + if (clad::HasOption(bitmasked_opts_value, clad::opts::immediate_mode)) + request.ImmediateMode = true; if (enable_tbr_in_req) { utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, "TBR analysis is not meant for forward mode AD."); diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index b89170528..7fb6ec5b3 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -139,8 +139,9 @@ namespace clad { opts); for (DiffRequest& request : m_DiffRequestGraph.getNodes()) { - if (!request.Function->isImmediateFunction() && - !request.Function->isConstexpr()) + if (!request.ImmediateMode || + (!request.Function->isImmediateFunction() && + !request.Function->isConstexpr())) continue; m_DiffRequestGraph.setCurrentProcessingNode(request);