From b2e0204c7d7a784d9f8bc1cbf9cabbc668fa4599 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Mon, 16 Dec 2024 19:20:31 +0000 Subject: [PATCH] Simplify handling of diff request options. NFC --- lib/Differentiator/DiffPlanner.cpp | 306 ++++++++++++++++------------- 1 file changed, 164 insertions(+), 142 deletions(-) diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 1f1fe761f..d2c39c1d9 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -654,6 +654,135 @@ namespace clad { return found != m_ActivityRunInfo.ToBeRecorded.end(); } + ///\returns true on error. + static bool ProcessInvocationArgs(Sema& S, SourceLocation endLoc, + const RequestOptions& ReqOpts, + const FunctionDecl* FD, + DiffRequest& request) { + const AnnotateAttr* A = FD->getAttr(); + if (A->getAnnotation().equals("E")) { + // Error estimation has no options yet. + request.Mode = DiffMode::error_estimation; + return false; + } + + if (A->getAnnotation().equals("D")) + request.Mode = DiffMode::forward; + else if (A->getAnnotation().equals("H")) + request.Mode = DiffMode::hessian; + else if (A->getAnnotation().equals("J")) + request.Mode = DiffMode::jacobian; + else if (A->getAnnotation().equals("G")) + request.Mode = DiffMode::reverse; + else { + utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, "Unknown mode '%0'", + A->getAnnotation()); + return true; + } + + request.EnableTBRAnalysis = ReqOpts.EnableTBRAnalysis; + request.EnableVariedAnalysis = ReqOpts.EnableVariedAnalysis; + + const TemplateArgumentList* TAL = FD->getTemplateSpecializationArgs(); + assert(TAL && "Call must have specialization args!"); + + // bitmask_opts is a template pack of unsigned integers, so we need to + // do bitwise or of all the values to get the final value. + unsigned bitmasked_opts_value = 0; + const auto template_arg = TAL->get(0); + if (template_arg.getKind() == TemplateArgument::Pack) + for (const auto& arg : TAL->get(0).pack_elements()) + bitmasked_opts_value |= arg.getAsIntegral().getExtValue(); + + bool enable_tbr_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr); + bool disable_tbr_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr); + bool enable_va_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::enable_va); + bool disable_va_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::disable_va); + + // Sanity checks. + if (enable_tbr_in_req && disable_tbr_in_req) { + utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, + "Both enable and disable TBR options are specified."); + return true; + } + if (enable_va_in_req && disable_va_in_req) { + utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, + "Both enable and disable VA options are specified."); + return true; + } + if (enable_tbr_in_req && request.Mode == DiffMode::forward) { + utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, + "TBR analysis is not meant for forward mode AD."); + return true; + } + + // reverse vector mode is not yet supported. + if (request.Mode == DiffMode::reverse && + clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) { + utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, + "Reverse vector mode is not yet supported."); + return true; + } + + // Override the default value of TBR analysis. + if (enable_tbr_in_req || disable_tbr_in_req) + request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req; + + // Override the default value of TBR analysis. + if (enable_va_in_req || disable_va_in_req) + request.EnableVariedAnalysis = enable_va_in_req && !disable_va_in_req; + + // Check for clad::hessian. + if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) { + if (request.Mode == DiffMode::hessian) { + request.Mode = DiffMode::hessian_diagonal; + return false; + } + utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, + "Diagonal only option is only valid for Hessian mode."); + return true; + } + + if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) + request.use_enzyme = true; + + if (request.Mode == DiffMode::forward) { + // Check for clad::differentiate. + if (unsigned order = clad::GetDerivativeOrder(bitmasked_opts_value)) + request.RequestedDerivativeOrder = order; + + // Check for clad::differentiate. + if (clad::HasOption(bitmasked_opts_value, clad::opts::immediate_mode)) + request.ImmediateMode = true; + + // Check for clad::differentiate. + if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) { + request.Mode = DiffMode::vector_forward_mode; + + // Currently only first order derivative is supported. + if (request.RequestedDerivativeOrder != 1) { + utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, + "Only first order derivative is supported for now " + "in vector forward mode."); + return true; + } + + // We don't yet support enzyme with vector mode. + if (request.use_enzyme) { + utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, + "Enzyme's vector mode is not yet supported."); + return true; + } + } + } + + return false; + } + bool DiffCollector::VisitCallExpr(CallExpr* E) { // Check if we should look into this. // FIXME: Generated code does not usually have valid source locations. @@ -666,153 +795,46 @@ namespace clad { FunctionDecl* FD = E->getDirectCallee(); if (!FD) return true; + // We need to find our 'special' diff annotated such: // clad::differentiate(...) __attribute__((annotate("D"))) // TODO: why not check for its name? clad::differentiate/gradient? const AnnotateAttr* A = FD->getAttr(); - if (A && - (A->getAnnotation().equals("D") || A->getAnnotation().equals("G") || - A->getAnnotation().equals("H") || A->getAnnotation().equals("J") || - A->getAnnotation().equals("E"))) { - // A call to clad::differentiate or clad::gradient was found. - DeclRefExpr* DRE = getArgFunction(E, m_Sema); - if (!DRE) - return true; - DiffRequest request{}; - - // bitmask_opts is a template pack of unsigned integers, so we need to - // do bitwise or of all the values to get the final value. - unsigned bitmasked_opts_value = 0; - bool enable_tbr_in_req = false; - bool disable_tbr_in_req = false; - bool enable_va_in_req = false; - bool disable_va_in_req = false; - if (!A->getAnnotation().equals("E") && - FD->getTemplateSpecializationArgs()) { - const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); - if (template_arg.getKind() == TemplateArgument::Pack) - for (const auto& arg : - FD->getTemplateSpecializationArgs()->get(0).pack_elements()) - bitmasked_opts_value |= arg.getAsIntegral().getExtValue(); - else - bitmasked_opts_value = template_arg.getAsIntegral().getExtValue(); - - // Set option for TBR analysis. - enable_tbr_in_req = - clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr); - disable_tbr_in_req = - clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr); - // Set option for Activity analysis. - enable_va_in_req = - clad::HasOption(bitmasked_opts_value, clad::opts::enable_va); - disable_va_in_req = - clad::HasOption(bitmasked_opts_value, clad::opts::disable_va); - if (enable_tbr_in_req && disable_tbr_in_req) { - utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, - "Both enable and disable TBR options are specified."); - return true; - } - if (enable_va_in_req && disable_va_in_req) { - utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, - "Both enable and disable VA options are specified."); - return true; - } - if (enable_tbr_in_req || disable_tbr_in_req) { - // override the default value of TBR analysis. - request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req; - } else { - request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis; - } - if (enable_va_in_req || disable_va_in_req) { - // override the default value of TBR analysis. - request.EnableVariedAnalysis = enable_va_in_req && !disable_va_in_req; - } else { - request.EnableVariedAnalysis = m_Options.EnableVariedAnalysis; - } - if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) { - if (!A->getAnnotation().equals("H")) { - utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, - "Diagonal only option is only valid for Hessian " - "mode."); - return true; - } - } - } + if (!A) + return true; + if (!A->getAnnotation().equals("D") && !A->getAnnotation().equals("G") && + !A->getAnnotation().equals("H") && !A->getAnnotation().equals("J") && + !A->getAnnotation().equals("E")) + return true; - if (A->getAnnotation().equals("D")) { - request.Mode = DiffMode::forward; - unsigned derivative_order = - clad::GetDerivativeOrder(bitmasked_opts_value); - if (derivative_order == 0) { - derivative_order = 1; // default to first order derivative. - } - request.RequestedDerivativeOrder = derivative_order; - if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) - request.use_enzyme = true; - if (clad::HasOption(bitmasked_opts_value, clad::opts::immediate_mode)) - request.ImmediateMode = true; - if (enable_tbr_in_req) { - utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, - "TBR analysis is not meant for forward mode AD."); - return true; - } - if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) { - request.Mode = DiffMode::vector_forward_mode; - - // currently only first order derivative is supported. - if (derivative_order != 1) { - utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, - "Only first order derivative is supported for now " - "in vector forward mode."); - return true; - } - // we don't yet support enzyme with vector mode. - if (request.use_enzyme) { - utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, - "Enzyme's vector mode is not yet supported."); - return true; - } - } - } else if (A->getAnnotation().equals("H")) { - if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) - request.Mode = DiffMode::hessian_diagonal; - else - request.Mode = DiffMode::hessian; - } else if (A->getAnnotation().equals("J")) { - request.Mode = DiffMode::jacobian; - } else if (A->getAnnotation().equals("G")) { - request.Mode = DiffMode::reverse; - if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) - request.use_enzyme = true; - // reverse vector mode is not yet supported. - if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) { - utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, - "Reverse vector mode is not yet supported."); - return true; - } - } else { - request.Mode = DiffMode::error_estimation; - } - request.CallContext = E; - request.CallUpdateRequired = true; - request.VerboseDiags = true; - request.Args = E->getArg(1); - auto derivedFD = cast(DRE->getDecl()); - request.Function = derivedFD; - request.BaseFunctionName = utils::ComputeEffectiveFnName(request.Function); - - if (isCallOperator(m_Sema.getASTContext(), request.Function)) { - request.Functor = cast(request.Function)->getParent(); - } - // FIXME: add support for nested calls to clad::differentiate/gradient - // inside differentiated functions - assert(!m_TopMostFD && - "nested clad::differentiate/gradient are not yet supported"); - llvm::SaveAndRestore saveTopMost = m_TopMostFD; - m_TopMostFD = FD; - TraverseDecl(derivedFD); - m_DiffRequestGraph.addNode(request, /*isSource=*/true); - } + // A call to clad::differentiate or clad::gradient was found. + DeclRefExpr* DRE = getArgFunction(E, m_Sema); + if (!DRE) + return true; + + DiffRequest request; + + if (ProcessInvocationArgs(m_Sema, endLoc, m_Options, FD, request)) + return true; + + request.CallContext = E; + request.CallUpdateRequired = true; + request.VerboseDiags = true; + request.Args = E->getArg(1); + auto* derivedFD = cast(DRE->getDecl()); + request.Function = derivedFD; + request.BaseFunctionName = utils::ComputeEffectiveFnName(request.Function); + + if (isCallOperator(m_Sema.getASTContext(), request.Function)) + request.Functor = cast(request.Function)->getParent(); + // FIXME: add support for nested calls to clad::differentiate/gradient + // inside differentiated functions + assert(!m_TopMostFD && + "nested clad::differentiate/gradient are not yet supported"); + llvm::SaveAndRestore saveTopMost = m_TopMostFD; + m_TopMostFD = FD; + TraverseDecl(derivedFD); + m_DiffRequestGraph.addNode(request, /*isSource=*/true); /*else if (m_TopMostFD) { // If another function is called inside differentiated function, // this will be handled by Forward/ReverseModeVisitor::Derive.