Skip to content

Commit

Permalink
Use a single point to process non-differentiable functions
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 3, 2024
1 parent 48b76c4 commit cba208d
Showing 1 changed file with 21 additions and 44 deletions.
65 changes: 21 additions & 44 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1467,27 +1467,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (const auto* KCE = dyn_cast<CUDAKernelCallExpr>(CE))
CUDAExecConfig = Clone(KCE->getConfig());

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
llvm::SmallVector<Expr*, 4> ClonedArgs;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i)
ClonedArgs.push_back(Clone(CE->getArg(i)));

SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema);
Expr* Call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc, CUDAExecConfig)
.get();
// Creating a zero derivative
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);

// Returning the function call and zero derivative
return StmtDiff(Call, zero);
}

// begin and end are common enough to have a more efficient and nice-looking
// special case. Instead of _forw and a useless _pullback functions, we can
// express the result in terms of the same std::begin / std::end. Note:
Expand All @@ -1512,13 +1491,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
// contribute to gradient.
if ((NArgs == 0U) && !isa<CXXMemberCallExpr>(CE) &&
!isa<CXXOperatorCallExpr>(CE))
return StmtDiff(Clone(CE));

SourceLocation Loc = CE->getExprLoc();

// Stores the call arguments for the function to be derived
Expand Down Expand Up @@ -1603,11 +1575,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff();
}

bool nonDiff = clad::utils::hasNonDifferentiableAttribute(CE);

// If the result does not depend on the result of the call, just clone
// the call and visit arguments (since they may contain side-effects like
// f(x = y))
// If the callee function takes arguments by reference then it can affect
// derivatives even if there is no `dfdx()` and thus we should call the
// derived function. In the case of member functions, `implicit`
// this object is always passed by reference.
if (!nonDiff && !dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE))
nonDiff = true;

// If all arguments are constant literals, then this does not contribute to
// the gradient.
// FIXME: revert this when this is integrated in the activity analysis pass.
if (!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
bool allArgsAreConstantLiterals = true;
if (!nonDiff && !isa<CXXMemberCallExpr>(CE) &&
!isa<CXXOperatorCallExpr>(CE)) {
bool allArgsAreConstant = true;
for (const Expr* arg : CE->arguments()) {
// if it's of type MaterializeTemporaryExpr, then check its
// subexpression.
Expand All @@ -1634,25 +1619,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
} analyzer(m_DiffReq);
if (analyzer.isVariedE(arg)) {
allArgsAreConstantLiterals = false;
allArgsAreConstant = false;
break;
}
}
if (allArgsAreConstantLiterals)
return StmtDiff(Clone(CE), Clone(CE));
if (allArgsAreConstant)
nonDiff = true;
}

// If the result does not depend on the result of the call, just clone
// the call and visit arguments (since they may contain side-effects like
// f(x = y))
// If the callee function takes arguments by reference then it can affect
// derivatives even if there is no `dfdx()` and thus we should call the
// derived function. In the case of member functions, `implicit`
// this object is always passed by reference.
if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
if (nonDiff) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
StmtDiff ArgDiff = Visit(Arg);
CallArgs.push_back(ArgDiff.getExpr());
}
Expr* call =
Expand Down

0 comments on commit cba208d

Please sign in to comment.