Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify handling of diff request options. NFC #1174

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 164 additions & 142 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,135 @@
return found != m_ActivityRunInfo.ToBeRecorded.end();
}

///\returns true on error.
static bool ProcessInvocationArgs(Sema& S, SourceLocation endLoc,
const RequestOptions& ReqOpts,
const FunctionDecl* FD,
DiffRequest& request) {
const AnnotateAttr* A = FD->getAttr<AnnotateAttr>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: no header providing "clang::AnnotateAttr" is directly included [misc-include-cleaner]

    const AnnotateAttr* A = FD->getAttr<AnnotateAttr>();
          ^

if (A->getAnnotation().equals("E")) {
// Error estimation has no options yet.
request.Mode = DiffMode::error_estimation;
return false;
}

if (A->getAnnotation().equals("D"))
request.Mode = DiffMode::forward;
else if (A->getAnnotation().equals("H"))
request.Mode = DiffMode::hessian;
else if (A->getAnnotation().equals("J"))
request.Mode = DiffMode::jacobian;
else if (A->getAnnotation().equals("G"))
request.Mode = DiffMode::reverse;
else {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, "Unknown mode '%0'",
A->getAnnotation());
return true;

Check warning on line 680 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L678-L680

Added lines #L678 - L680 were not covered by tests
}

request.EnableTBRAnalysis = ReqOpts.EnableTBRAnalysis;
request.EnableVariedAnalysis = ReqOpts.EnableVariedAnalysis;

const TemplateArgumentList* TAL = FD->getTemplateSpecializationArgs();
assert(TAL && "Call must have specialization args!");

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
const auto template_arg = TAL->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg : TAL->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();

bool enable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr);
bool disable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr);
bool enable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_va);
bool disable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_va);

// Sanity checks.
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
return true;
}
if (enable_va_in_req && disable_va_in_req) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Both enable and disable VA options are specified.");
return true;
}
if (enable_tbr_in_req && request.Mode == DiffMode::forward) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"TBR analysis is not meant for forward mode AD.");
return true;
}

// reverse vector mode is not yet supported.
if (request.Mode == DiffMode::reverse &&
clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Reverse vector mode is not yet supported.");
return true;
}

// Override the default value of TBR analysis.
if (enable_tbr_in_req || disable_tbr_in_req)
request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req;

// Override the default value of TBR analysis.
if (enable_va_in_req || disable_va_in_req)
request.EnableVariedAnalysis = enable_va_in_req && !disable_va_in_req;

// Check for clad::hessian<diagonal_only>.
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (request.Mode == DiffMode::hessian) {
request.Mode = DiffMode::hessian_diagonal;
return false;
}
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Diagonal only option is only valid for Hessian mode.");
return true;
}

if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;

if (request.Mode == DiffMode::forward) {
// Check for clad::differentiate<N>.
if (unsigned order = clad::GetDerivativeOrder(bitmasked_opts_value))
request.RequestedDerivativeOrder = order;

// Check for clad::differentiate<immediate_mode>.
if (clad::HasOption(bitmasked_opts_value, clad::opts::immediate_mode))
request.ImmediateMode = true;

// Check for clad::differentiate<vector_mode>.
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
request.Mode = DiffMode::vector_forward_mode;

// Currently only first order derivative is supported.
if (request.RequestedDerivativeOrder != 1) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Only first order derivative is supported for now "
"in vector forward mode.");
return true;
}

// We don't yet support enzyme with vector mode.
if (request.use_enzyme) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Enzyme's vector mode is not yet supported.");
return true;
}
}
}

return false;
}

bool DiffCollector::VisitCallExpr(CallExpr* E) {
// Check if we should look into this.
// FIXME: Generated code does not usually have valid source locations.
Expand All @@ -666,153 +795,46 @@
FunctionDecl* FD = E->getDirectCallee();
if (!FD)
return true;

// We need to find our 'special' diff annotated such:
// clad::differentiate(...) __attribute__((annotate("D")))
// TODO: why not check for its name? clad::differentiate/gradient?
const AnnotateAttr* A = FD->getAttr<AnnotateAttr>();
if (A &&
(A->getAnnotation().equals("D") || A->getAnnotation().equals("G") ||
A->getAnnotation().equals("H") || A->getAnnotation().equals("J") ||
A->getAnnotation().equals("E"))) {
// A call to clad::differentiate or clad::gradient was found.
DeclRefExpr* DRE = getArgFunction(E, m_Sema);
if (!DRE)
return true;
DiffRequest request{};

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
bool enable_tbr_in_req = false;
bool disable_tbr_in_req = false;
bool enable_va_in_req = false;
bool disable_va_in_req = false;
if (!A->getAnnotation().equals("E") &&
FD->getTemplateSpecializationArgs()) {
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();

// Set option for TBR analysis.
enable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr);
disable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr);
// Set option for Activity analysis.
enable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_va);
disable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_va);
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
return true;
}
if (enable_va_in_req && disable_va_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable VA options are specified.");
return true;
}
if (enable_tbr_in_req || disable_tbr_in_req) {
// override the default value of TBR analysis.
request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req;
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
if (enable_va_in_req || disable_va_in_req) {
// override the default value of TBR analysis.
request.EnableVariedAnalysis = enable_va_in_req && !disable_va_in_req;
} else {
request.EnableVariedAnalysis = m_Options.EnableVariedAnalysis;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (!A->getAnnotation().equals("H")) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Diagonal only option is only valid for Hessian "
"mode.");
return true;
}
}
}
if (!A)
return true;
if (!A->getAnnotation().equals("D") && !A->getAnnotation().equals("G") &&
!A->getAnnotation().equals("H") && !A->getAnnotation().equals("J") &&
!A->getAnnotation().equals("E"))
return true;

if (A->getAnnotation().equals("D")) {
request.Mode = DiffMode::forward;
unsigned derivative_order =
clad::GetDerivativeOrder(bitmasked_opts_value);
if (derivative_order == 0) {
derivative_order = 1; // default to first order derivative.
}
request.RequestedDerivativeOrder = derivative_order;
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
if (clad::HasOption(bitmasked_opts_value, clad::opts::immediate_mode))
request.ImmediateMode = true;
if (enable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"TBR analysis is not meant for forward mode AD.");
return true;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
request.Mode = DiffMode::vector_forward_mode;

// currently only first order derivative is supported.
if (derivative_order != 1) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Only first order derivative is supported for now "
"in vector forward mode.");
return true;
}
// we don't yet support enzyme with vector mode.
if (request.use_enzyme) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Enzyme's vector mode is not yet supported.");
return true;
}
}
} else if (A->getAnnotation().equals("H")) {
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only))
request.Mode = DiffMode::hessian_diagonal;
else
request.Mode = DiffMode::hessian;
} else if (A->getAnnotation().equals("J")) {
request.Mode = DiffMode::jacobian;
} else if (A->getAnnotation().equals("G")) {
request.Mode = DiffMode::reverse;
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
// reverse vector mode is not yet supported.
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Reverse vector mode is not yet supported.");
return true;
}
} else {
request.Mode = DiffMode::error_estimation;
}
request.CallContext = E;
request.CallUpdateRequired = true;
request.VerboseDiags = true;
request.Args = E->getArg(1);
auto derivedFD = cast<FunctionDecl>(DRE->getDecl());
request.Function = derivedFD;
request.BaseFunctionName = utils::ComputeEffectiveFnName(request.Function);

if (isCallOperator(m_Sema.getASTContext(), request.Function)) {
request.Functor = cast<CXXMethodDecl>(request.Function)->getParent();
}
// FIXME: add support for nested calls to clad::differentiate/gradient
// inside differentiated functions
assert(!m_TopMostFD &&
"nested clad::differentiate/gradient are not yet supported");
llvm::SaveAndRestore<const FunctionDecl*> saveTopMost = m_TopMostFD;
m_TopMostFD = FD;
TraverseDecl(derivedFD);
m_DiffRequestGraph.addNode(request, /*isSource=*/true);
}
// A call to clad::differentiate or clad::gradient was found.
DeclRefExpr* DRE = getArgFunction(E, m_Sema);
if (!DRE)
return true;

DiffRequest request;

if (ProcessInvocationArgs(m_Sema, endLoc, m_Options, FD, request))
return true;

request.CallContext = E;
request.CallUpdateRequired = true;
request.VerboseDiags = true;
request.Args = E->getArg(1);
auto* derivedFD = cast<FunctionDecl>(DRE->getDecl());
request.Function = derivedFD;
request.BaseFunctionName = utils::ComputeEffectiveFnName(request.Function);

if (isCallOperator(m_Sema.getASTContext(), request.Function))
request.Functor = cast<CXXMethodDecl>(request.Function)->getParent();
// FIXME: add support for nested calls to clad::differentiate/gradient
// inside differentiated functions
assert(!m_TopMostFD &&
"nested clad::differentiate/gradient are not yet supported");
llvm::SaveAndRestore<const FunctionDecl*> saveTopMost = m_TopMostFD;
m_TopMostFD = FD;
TraverseDecl(derivedFD);
m_DiffRequestGraph.addNode(request, /*isSource=*/true);
/*else if (m_TopMostFD) {
// If another function is called inside differentiated function,
// this will be handled by Forward/ReverseModeVisitor::Derive.
Expand Down
Loading