Skip to content

Commit

Permalink
Generate non-differentiable type adjoints for custom derivative calls
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Apr 15, 2024
1 parent c4d0fbf commit ad770a7
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 26 deletions.
5 changes: 3 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,9 @@ namespace clad {
/// null otherwise.
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true);
clang::Scope* S, const clang::FunctionDecl* originalFD,
bool forCustomDerv = true, bool namespaceShouldExist = true,
llvm::SmallVectorImpl<clang::Stmt*>* block = nullptr);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
};
Expand Down
3 changes: 1 addition & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
Expr* callDiff = BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
customPushforward, customDerivativeArgs, getCurrentScope(), FD);

// Check if it is a recursive call.
if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) {
Expand Down
14 changes: 9 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1670,8 +1670,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pushforwardCallArgs.push_back(ConstantFolder::synthesizeLiteral(
DerivedCallArgs.front()->getType(), m_Context, 1));
OverloadedDerivedFn = BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
customPushforward, pushforwardCallArgs, getCurrentScope(), FD);
if (OverloadedDerivedFn)
asGrad = false;
}
Expand Down Expand Up @@ -1766,8 +1765,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
OverloadedDerivedFn = BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
customPullback, pullbackCallArgs, getCurrentScope(), FD,
/*forCustomDerv=*/true,
/*namespaceShouldExist=*/true,
/*block=*/&PreCallStmts);
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
}
Expand Down Expand Up @@ -2042,8 +2043,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
NumDiffArgs.push_back(args[i]);
}
std::string Name = "central_difference";
const FunctionDecl* FD = nullptr;
if (auto* DRE = dyn_cast<DeclRefExpr>(targetFuncCall->IgnoreImplicit()))
FD = dyn_cast<FunctionDecl>(DRE->getDecl());
return BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
Name, NumDiffArgs, getCurrentScope(), /*OriginalFD=*/FD,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
}
Expand Down
48 changes: 40 additions & 8 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,8 +734,11 @@ namespace clad {
NumDiffArgs.insert(NumDiffArgs.end(), args.begin(), args.begin() + numArgs);
// Return the found overload.
std::string Name = "forward_central_difference";
const FunctionDecl* FD = nullptr;
if (auto* DRE = dyn_cast<DeclRefExpr>(targetFuncCall->IgnoreImplicit()))
FD = dyn_cast<FunctionDecl>(DRE->getDecl());
return BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
Name, NumDiffArgs, getCurrentScope(), /*OriginalFD=*/FD,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
}
Expand Down Expand Up @@ -818,8 +821,12 @@ namespace clad {

Expr* VisitorBase::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
clang::Scope* S, const clang::FunctionDecl* originalFD,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/,
llvm::SmallVectorImpl<Stmt*>* block /*=nullptr*/) {
DeclContext* originalFnDC = nullptr;
if (originalFD)
originalFnDC = const_cast<DeclContext*>(originalFD->getDeclContext());
NamespaceDecl* NSD = nullptr;
std::string namespaceID;
if (forCustomDerv) {
Expand All @@ -841,7 +848,7 @@ namespace clad {
NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist);
if (!forCustomDerv && !NSD) {
diag(DiagnosticsEngine::Warning, noLoc,
"Numerical differentiation is diabled using the "
"Numerical differentiation is disabled using the "
"-DCLAD_NO_NUM_DIFF "
"flag, this means that every try to numerically differentiate a "
"function will fail! Remove the flag to revert to default "
Expand Down Expand Up @@ -889,15 +896,40 @@ namespace clad {
Expr* UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get();

auto MARargs = llvm::MutableArrayRef<Expr*>(CallArgs);

SourceLocation Loc;
llvm::SmallVector<Expr*, 16> ExtendedCallArgs(CallArgs.begin(),
CallArgs.end());
llvm::SmallVector<Stmt*, 16> DeclStmts;
// FIXME: for now, integer types are considered differentiable in the
// forward mode.
if (m_Mode != DiffMode::forward &&
m_Mode != DiffMode::vector_forward_mode &&
m_Mode != DiffMode::experimental_pushforward)
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!utils::IsDifferentiableType(paramTy)) {
QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema);
VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy));
Expr* arg = BuildDeclRef(argDecl);
if (!utils::isArrayOrPointerType(argTy))
arg = BuildOp(UO_AddrOf, arg);
ExtendedCallArgs.insert(ExtendedCallArgs.begin() + e + i + 1, arg);
DeclStmts.push_back(BuildDeclStmt(argDecl));
}
}
auto MARargs = llvm::MutableArrayRef<Expr*>(ExtendedCallArgs);

if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();
m_Sema.ActOnCallExpr(S, UnresolvedLookup, noLoc, MARargs, noLoc)
.get();
if (!DeclStmts.empty()) {
if (!block)
block = &getCurrentBlock();
for (Stmt* decl : DeclStmts)
block->push_back(decl);
}
}
return OverloadedFn;
}
Expand Down
11 changes: 4 additions & 7 deletions test/FirstDerivative/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,16 @@ double f10(float x, int y) {
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

void f10_grad(float x, int y, float *_d_x, int *_d_y);
void f10_grad_0(float x, int y, float *_d_x);

// CHECK: void f10_grad(float x, int y, float *_d_x, int *_d_y) {
// CHECK: void f10_grad_0(float x, int y, float *_d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: int _r1 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_r0, &_r1);
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: *_d_y += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down Expand Up @@ -222,7 +221,6 @@ double f12(double a, double b) { return std::fma(a, b, b); }
int main () { //expected-no-diagnostics
float f_result[2];
double d_result[2];
int i_result[1];

auto f1_darg0 = clad::differentiate(f1, 0);
printf("Result is = %f\n", f1_darg0.execute(60)); // CHECK-EXEC: Result is = -0.952413
Expand Down Expand Up @@ -276,10 +274,9 @@ int main () { //expected-no-diagnostics
printf("Result is = %f\n", f10_darg0.execute(3, 4)); //CHECK-EXEC: Result is = 108.000000

f_result[0] = f_result[1] = 0;
i_result[0] = 0;
clad::gradient(f10);
f10_grad(3, 4, &f_result[0], &i_result[0]);
printf("Result is = {%f, %d}\n", f_result[0], i_result[0]); //CHECK-EXEC: Result is = {108.000000, 88}
f10_grad_0(3, 4, &f_result[0]);
printf("Result is = {%f}\n", f_result[0]); //CHECK-EXEC: Result is = {108.000000}

INIT_GRADIENT(f11);

Expand Down
4 changes: 2 additions & 2 deletions test/NumericalDiff/NoNumDiff.C
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

double func(double x) { return std::tanh(x); }

//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour.
//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour.
//CHECK: warning: Numerical differentiation is disabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour.
//CHECK: warning: Numerical differentiation is disabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour.
//CHECK: double func_darg0(double x) {
//CHECK-NEXT: double _d_x = 1;
//CHECK-NEXT: return 0;
Expand Down

0 comments on commit ad770a7

Please sign in to comment.