Skip to content

Commit

Permalink
Respect shadow declarations when writing propagators.
Browse files Browse the repository at this point in the history
In cases where the public declaration is introduced with using declaration
pointing to an internal namespace with the implementation details, we should
put the propagator function in the namespace of the public function and not the
implementation. That would allow users to position their pullbacks in the same
namespace structure as the used functions.
  • Loading branch information
vgvassilev committed Dec 14, 2024
1 parent cebc426 commit 27958c3
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 35 deletions.
6 changes: 4 additions & 2 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ namespace clad {
/// overload to be found.
/// \param[in] CallArgs The call args to be used to resolve to the
/// correct overload.
/// \param[in] callSite - The call expression which triggers the custom
/// derivative call.
/// \param[in] forCustomDerv A flag to keep track of which
/// namespace we should look in for the overloads.
/// \param[in] namespaceShouldExist A flag to enforce assertion failure
Expand All @@ -117,8 +119,8 @@ namespace clad {
/// null otherwise.
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, const clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true,
clang::Scope* S, const clang::Expr* callSite, bool forCustomDerv = true,
bool namespaceShouldExist = true,
clang::Expr* CUDAExecConfig = nullptr);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ namespace clad {

/// Tries to find and build call to user-provided `_forw` function.
clang::Expr* BuildCallToCustomForwPassFn(
const clang::FunctionDecl* FD, llvm::ArrayRef<clang::Expr*> primalArgs,
const clang::Expr* callSite, llvm::ArrayRef<clang::Expr*> primalArgs,
llvm::ArrayRef<clang::Expr*> derivedArgs, clang::Expr* baseExpr);

public:
Expand Down
9 changes: 3 additions & 6 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1226,17 +1226,15 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
FD->getDeclContext());
customPushforward, customDerivativeArgs, getCurrentScope(), CE);
// Custom derivative templates can be written in a
// general way that works for both vectorized and non-vectorized
// modes. We have to also look for the pushforward with the regular name.
if (!callDiff && m_DiffReq.Mode != DiffMode::forward) {
customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + "_pushforward";
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
FD->getDeclContext());
customPushforward, customDerivativeArgs, getCurrentScope(), CE);
}
if (!isLambda) {
// Check if it is a recursive call.
Expand Down Expand Up @@ -2316,8 +2314,7 @@ clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall(
clad::utils::ComputeEffectiveFnName(CE->getConstructor()) +
GetPushForwardFunctionSuffix();
Expr* pushforwardCall = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforwardName, customPushforwardArgs, getCurrentScope(),
CE->getConstructor()->getDeclContext());
customPushforwardName, customPushforwardArgs, getCurrentScope(), CE);
return pushforwardCall;
}
} // end namespace clad
22 changes: 21 additions & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,29 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {

Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, const clang::DeclContext* originalFnDC,
clang::Scope* S, const clang::Expr* callSite,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/,
Expr* CUDAExecConfig /*=nullptr*/) {
const DeclContext* originalFnDC = nullptr;

// FIXME: callSite must not be numm but it comes when we try to build
// a numerical diff call. We should merge both paths and remove the
// special branches being taken for propagators and numerical diff.
if (callSite) {
// Check if the callSite is not associated with a shadow declaration.
if (auto* ME = dyn_cast<CXXMemberCallExpr>(callSite)) {
originalFnDC = ME->getMethodDecl()->getParent();
} else if (auto* CE = dyn_cast<CallExpr>(callSite)) {
const Expr* Callee = CE->getCallee()->IgnoreParenCasts();
if (auto* DRE = dyn_cast<DeclRefExpr>(Callee))
originalFnDC = DRE->getFoundDecl()->getDeclContext();
else if (auto* MemberE = dyn_cast<MemberExpr>(Callee))
originalFnDC = MemberE->getFoundDecl().getDecl()->getDeclContext();
} else if (auto* CtorExpr = dyn_cast<CXXConstructExpr>(callSite)) {
originalFnDC = CtorExpr->getConstructor()->getDeclContext();
}
}

CXXScopeSpec SS;
LookupResult R = LookupCustomDerivativeOrNumericalDiff(
Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist);
Expand Down
33 changes: 18 additions & 15 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1827,8 +1827,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DerivedCallArgs.front()->getType(), m_Context, 1));
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
FD->getDeclContext(),
customPushforward, pushforwardCallArgs, getCurrentScope(), CE,
/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (OverloadedDerivedFn)
Expand Down Expand Up @@ -1931,8 +1930,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
FD->getDeclContext(),
customPullback, pullbackCallArgs, getCurrentScope(), CE,
/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (baseDiff.getExpr())
Expand Down Expand Up @@ -2064,7 +2062,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
baseDiff.getExpr_dx(), Loc));

if (Expr* customForwardPassCE =
BuildCallToCustomForwPassFn(FD, CallArgs, CallArgDx, baseExpr)) {
BuildCallToCustomForwPassFn(CE, CallArgs, CallArgDx, baseExpr)) {
if (!utils::isNonConstReferenceType(returnType) &&
!returnType->isPointerType())
return StmtDiff{customForwardPassCE};
Expand Down Expand Up @@ -2214,7 +2212,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::string Name = "central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*callSite=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false, CUDAExecConfig);
}
Expand Down Expand Up @@ -4247,8 +4245,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::string customPullbackName = "constructor_pullback";
if (Expr* customPullbackCall =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullbackName, pullbackArgs, getCurrentScope(),
CE->getConstructor()->getDeclContext())) {
customPullbackName, pullbackArgs, getCurrentScope(), CE)) {
curRevBlock.insert(it, customPullbackCall);
if (m_TrackConstructorPullbackInfo) {
setConstructorPullbackCallInfo(llvm::cast<CallExpr>(customPullbackCall),
Expand Down Expand Up @@ -4278,9 +4275,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// SomeClass _d_c = _t0.adjoint;
// SomeClass c = _t0.value;
// ```
if (Expr* customReverseForwFnCall = BuildCallToCustomForwPassFn(
CE->getConstructor(), primalArgs, reverseForwAdjointArgs,
/*baseExpr=*/nullptr)) {
if (Expr* customReverseForwFnCall =
BuildCallToCustomForwPassFn(CE, primalArgs, reverseForwAdjointArgs,
/*baseExpr=*/nullptr)) {
if (RD->isAggregate()) {
SmallString<128> Name_class;
llvm::raw_svector_ostream OS_class(Name_class);
Expand Down Expand Up @@ -4555,16 +4552,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

Expr* ReverseModeVisitor::BuildCallToCustomForwPassFn(
const FunctionDecl* FD, llvm::ArrayRef<Expr*> primalArgs,
const Expr* callSite, llvm::ArrayRef<Expr*> primalArgs,
llvm::ArrayRef<clang::Expr*> derivedArgs, Expr* baseExpr) {
std::string forwPassFnName =
clad::utils::ComputeEffectiveFnName(FD) + "_reverse_forw";
llvm::SmallVector<Expr*, 4> args;
if (baseExpr) {
baseExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, baseExpr,
m_DiffReq->getLocation());
args.push_back(baseExpr);
}
const FunctionDecl* FD = nullptr;
if (auto* CE = dyn_cast<CallExpr>(callSite))
FD = CE->getDirectCallee();
else
FD = cast<CXXConstructExpr>(callSite)->getConstructor();

if (auto CD = llvm::dyn_cast<CXXConstructorDecl>(FD)) {
const RecordDecl* RD = CD->getParent();
QualType constructorReverseForwTagT =
Expand All @@ -4582,9 +4583,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
args.append(primalArgs.begin(), primalArgs.end());
args.append(derivedArgs.begin(), derivedArgs.end());
std::string forwPassFnName =
clad::utils::ComputeEffectiveFnName(FD) + "_reverse_forw";
Expr* customForwPassCE =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
forwPassFnName, args, getCurrentScope(), FD->getDeclContext());
forwPassFnName, args, getCurrentScope(), callSite);
return customForwPassCE;
}

Expand Down
25 changes: 25 additions & 0 deletions test/FirstDerivative/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,28 @@
#include "../TestUtils.h"
extern "C" int printf(const char* fmt, ...);


namespace N {
namespace impl {
double sq(double x);
}
using impl::sq; // using shadow
}

namespace clad {
namespace custom_derivatives {
namespace N {
clad::ValueAndPushforward<double, double> sq_pushforward(double x, double *d_x) {
return { x * x, 2 * x };
}
}
}
}

float f0 (float x) {
return N::sq(x); // must find the sq_pushforward.
}

namespace clad{
namespace custom_derivatives{
float f1_darg0(float x) {
Expand Down Expand Up @@ -296,6 +318,9 @@ int main () { //expected-no-diagnostics
double d_result[2];
int i_result[1];

auto f0_darg0 = clad::differentiate(f0, 0);
printf("Result is = %f\n", f0_darg0.execute(2)); // CHECK-EXEC: Result is = -0.952413

auto f1_darg0 = clad::differentiate(f1, 0);
printf("Result is = %f\n", f1_darg0.execute(60)); // CHECK-EXEC: Result is = -0.952413

Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ double f19(double a, double b) {
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: double _r2 = 0.;
//CHECK-NEXT: clad::custom_derivatives::fma_pullback(a, b, b, 1, &_r0, &_r1, &_r2);
//CHECK-NEXT: clad::custom_derivatives::std::fma_pullback(a, b, b, 1, &_r0, &_r1, &_r2);
//CHECK-NEXT: *_d_a += _r0;
//CHECK-NEXT: *_d_b += _r1;
//CHECK-NEXT: *_d_b += _r2;
Expand Down
6 changes: 3 additions & 3 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ void f_norm_grad(double x,
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: double _r5 = 0.;
//CHECK-NEXT: clad::custom_derivatives::pow_pullback(sum_of_powers(x, y, z, d), 1 / d, 1, &_r0, &_r5);
//CHECK-NEXT: clad::custom_derivatives::std::pow_pullback(sum_of_powers(x, y, z, d), 1 / d, 1, &_r0, &_r5);
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: double _r2 = 0.;
//CHECK-NEXT: double _r3 = 0.;
Expand All @@ -430,10 +430,10 @@ void f_sin_grad(double x, double y, double *_d_x, double *_d_y);
//CHECK-NEXT: double _t0 = (std::sin(x) + std::sin(y));
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: _r0 += 1 * (x + y) * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward;
//CHECK-NEXT: _r0 += 1 * (x + y) * clad::custom_derivatives::std::sin_pushforward(x, 1.).pushforward;
//CHECK-NEXT: *_d_x += _r0;
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: _r1 += 1 * (x + y) * clad::custom_derivatives::sin_pushforward(y, 1.).pushforward;
//CHECK-NEXT: _r1 += 1 * (x + y) * clad::custom_derivatives::std::sin_pushforward(y, 1.).pushforward;
//CHECK-NEXT: *_d_y += _r1;
//CHECK-NEXT: *_d_x += _t0 * 1;
//CHECK-NEXT: *_d_y += _t0 * 1;
Expand Down
4 changes: 2 additions & 2 deletions test/Jacobian/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ void fn1(double i, double j, double* output) {
// CHECK-NEXT: clad::array<double> _d_vector_i = clad::one_hot_vector(indepVarCount, {{0U|0UL|0ULL}});
// CHECK-NEXT: clad::array<double> _d_vector_j = clad::one_hot_vector(indepVarCount, {{1U|1UL|1ULL}});
// CHECK-NEXT: *_d_vector_output = clad::identity_matrix(_d_vector_output->rows(), indepVarCount, {{2U|2UL|2ULL}});
// CHECK-NEXT: {{.*}} _t0 = clad::custom_derivatives::pow_pushforward(i, j, _d_vector_i, _d_vector_j);
// CHECK-NEXT: {{.*}} _t0 = clad::custom_derivatives::std::pow_pushforward(i, j, _d_vector_i, _d_vector_j);
// CHECK-NEXT: *_d_vector_output[0] = _t0.pushforward;
// CHECK-NEXT: output[0] = _t0.value;
// CHECK-NEXT: {{.*}} _t1 = clad::custom_derivatives::pow_pushforward(j, i, _d_vector_j, _d_vector_i);
// CHECK-NEXT: {{.*}} _t1 = clad::custom_derivatives::std::pow_pushforward(j, i, _d_vector_j, _d_vector_i);
// CHECK-NEXT: *_d_vector_output[1] = _t1.pushforward;
// CHECK-NEXT: output[1] = _t1.value;
// CHECK-NEXT: }
Expand Down
8 changes: 4 additions & 4 deletions test/NestedCalls/NestedCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ int main () { // expected-no-diagnostics
// CHECK: clad::ValueAndPushforward<double, double> sq_pushforward(double x, double _d_x);

// CHECK: clad::ValueAndPushforward<double, double> one_pushforward(double x, double _d_x) {
// CHECK-NEXT: ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::sin_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::std::sin_pushforward(x, _d_x);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t1 = sq_pushforward(_t0.value, _t0.pushforward);
// CHECK-NEXT: ValueAndPushforward<double, double> _t2 = clad::custom_derivatives::cos_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<double, double> _t2 = clad::custom_derivatives::std::cos_pushforward(x, _d_x);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t3 = sq_pushforward(_t2.value, _t2.pushforward);
// CHECK-NEXT: return {_t1.value + _t3.value, _t1.pushforward + _t3.pushforward};
// CHECK-NEXT: }
Expand All @@ -71,12 +71,12 @@ int main () { // expected-no-diagnostics
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: sq_pullback(std::sin(x), _d_y, &_r0);
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: _r1 += _r0 * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward;
//CHECK-NEXT: _r1 += _r0 * clad::custom_derivatives::std::sin_pushforward(x, 1.).pushforward;
//CHECK-NEXT: *_d_x += _r1;
//CHECK-NEXT: double _r2 = 0.;
//CHECK-NEXT: sq_pullback(std::cos(x), _d_y, &_r2);
//CHECK-NEXT: double _r3 = 0.;
//CHECK-NEXT: _r3 += _r2 * clad::custom_derivatives::cos_pushforward(x, 1.).pushforward;
//CHECK-NEXT: _r3 += _r2 * clad::custom_derivatives::std::cos_pushforward(x, 1.).pushforward;
//CHECK-NEXT: *_d_x += _r3;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down

0 comments on commit 27958c3

Please sign in to comment.