From 01c80732a1e487ce9531a0783f2b3e3f92c757d6 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Mon, 19 Feb 2024 15:31:24 +0100 Subject: [PATCH] Fix struct init using initializer lists --- .../Differentiator/BaseForwardModeVisitor.h | 1 + .../clad/Differentiator/ReverseModeVisitor.h | 2 + lib/Differentiator/BaseForwardModeVisitor.cpp | 5 + lib/Differentiator/ReverseModeVisitor.cpp | 114 ++++++++++-------- test/ForwardMode/Pointer.C | 25 ++++ test/Gradient/Pointers.C | 40 +++++- 6 files changed, 133 insertions(+), 54 deletions(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 375ae88d5..4ca1fcae3 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -106,6 +106,7 @@ class BaseForwardModeVisitor StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); + StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE); virtual clang::QualType GetPushForwardDerivativeType(clang::QualType ParamType); diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index b38749bf1..dff250161 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -358,6 +358,8 @@ namespace clad { StmtDiff VisitForStmt(const clang::ForStmt* FS); StmtDiff VisitIfStmt(const clang::IfStmt* If); StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE); + StmtDiff + VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE); StmtDiff VisitInitListExpr(const clang::InitListExpr* ILE); StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL); StmtDiff VisitMemberExpr(const clang::MemberExpr* ME); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 37f32bf81..68979bcb3 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1445,6 +1445,11 @@ BaseForwardModeVisitor::VisitImplicitCastExpr(const ImplicitCastExpr* ICE) { return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx()); } +StmtDiff BaseForwardModeVisitor::VisitImplicitValueInitExpr( + const ImplicitValueInitExpr* E) { + return StmtDiff(Clone(E), Clone(E)); +} + StmtDiff BaseForwardModeVisitor::VisitCXXDefaultArgExpr(const CXXDefaultArgExpr* DE) { // FIXME: Shouldn't we simply clone the CXXDefaultArgExpr? diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index a9529270c..f42e8a2f4 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1247,6 +1247,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get(); return StmtDiff(clonedILE); } + // Check if type is a CXXRecordDecl and a struct. + if (!isCladValueAndPushforwardType(ILEType) && ILEType->isRecordType() && + ILEType->getAsCXXRecordDecl()->isStruct()) { + for (unsigned i = 0, e = ILE->getNumInits(); i < e; i++) { + // fetch ith field of the struct. + auto field_iterator = ILEType->getAsCXXRecordDecl()->field_begin(); + std::advance(field_iterator, i); + Expr* member_acess = utils::BuildMemberExpr( + m_Sema, getCurrentScope(), dfdx(), (*field_iterator)->getName()); + clonedExprs[i] = Visit(ILE->getInit(i), member_acess).getExpr(); + } + Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get(); + return StmtDiff(clonedILE); + } + // FIXME: This is a makeshift arrangement to differentiate an InitListExpr // that represents a ValueAndPushforward type. Ideally this must be // differentiated at VisitCXXConstructExpr @@ -2582,11 +2597,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bool isPointerType = VD->getType()->isPointerType(); bool isInitializedByNewExpr = false; // Check if the variable is pointer type and initialized by new expression - if (isPointerType && VD->getInit()) { - if (isa(VD->getInit())) { - isInitializedByNewExpr = true; - } - } + if (isPointerType && VD->getInit() && isa(VD->getInit())) + isInitializedByNewExpr = true; // VDDerivedInit now serves two purposes -- as the initial derivative value // or the size of the derivative array -- depending on the primal type. @@ -2655,7 +2667,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // if VD is a pointer type, then the initial value is set to the derived // expression of the corresponding pointer type. else if (isPointerType && VD->getInit()) { - initDiff = Visit(VD->getInit()); VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); // If it's a pointer to a constant type, then remove the constness. if (VD->getType()->getPointeeType().isConstQualified()) { @@ -2679,10 +2690,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // need to call `Visit` since non-local variables are not differentiated. if (!isDerivativeOfRefType && !(isPointerType && !isInitializedByNewExpr)) { Expr* derivedE = BuildDeclRef(VDDerived); - if (isInitializedByNewExpr) { - // derivedE should be dereferenced. + if (isInitializedByNewExpr) derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); - } if (VD->getInit()) { if (isa(VD->getInit())) initDiff = Visit(VD->getInit()); @@ -2709,6 +2718,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, getZeroInit(VDDerivedType)); addToCurrentBlock(assignToZero, direction::reverse); } + } else if (isPointerType && VD->getInit()) { + initDiff = Visit(VD->getInit()); } VarDecl* VDClone = nullptr; Expr* derivedVDE = BuildDeclRef(VDDerived); @@ -2926,6 +2937,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Visit(ICE->getSubExpr(), dfdx()); } + StmtDiff ReverseModeVisitor::VisitImplicitValueInitExpr( + const ImplicitValueInitExpr* IVIE) { + return {Clone(IVIE), Clone(IVIE)}; + } + StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) { auto baseDiff = VisitWithExplicitNoDfDx(ME->getBase()); auto* field = ME->getMemberDecl(); @@ -3722,47 +3738,47 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitCXXNewExpr(const clang::CXXNewExpr* CNE) { - StmtDiff initializerDiff; - if (CNE->hasInitializer()) - initializerDiff = Visit(CNE->getInitializer(), dfdx()); - - Expr* clonedArraySizeE = nullptr; - Expr* derivedArraySizeE = nullptr; - if (CNE->getArraySize()) { - clonedArraySizeE = - Visit(clad_compat::ArraySize_GetValue(CNE->getArraySize())).getExpr(); - // Array size is a non-differentiable expression, thus the original value - // should be used in both the cloned and the derived statements. - derivedArraySizeE = Clone(clonedArraySizeE); - } - Expr* clonedNewE = utils::BuildCXXNewExpr( - m_Sema, CNE->getAllocatedType(), clonedArraySizeE, - initializerDiff.getExpr(), CNE->getAllocatedTypeSourceInfo()); - Expr* derivedNewE = utils::BuildCXXNewExpr( - m_Sema, CNE->getAllocatedType(), derivedArraySizeE, - initializerDiff.getExpr_dx(), CNE->getAllocatedTypeSourceInfo()); - return {clonedNewE, derivedNewE}; -} + StmtDiff initializerDiff; + if (CNE->hasInitializer()) + initializerDiff = Visit(CNE->getInitializer(), dfdx()); + + Expr* clonedArraySizeE = nullptr; + Expr* derivedArraySizeE = nullptr; + if (CNE->getArraySize()) { + clonedArraySizeE = + Visit(clad_compat::ArraySize_GetValue(CNE->getArraySize())).getExpr(); + // Array size is a non-differentiable expression, thus the original value + // should be used in both the cloned and the derived statements. + derivedArraySizeE = Clone(clonedArraySizeE); + } + Expr* clonedNewE = utils::BuildCXXNewExpr( + m_Sema, CNE->getAllocatedType(), clonedArraySizeE, + initializerDiff.getExpr(), CNE->getAllocatedTypeSourceInfo()); + Expr* derivedNewE = utils::BuildCXXNewExpr( + m_Sema, CNE->getAllocatedType(), derivedArraySizeE, + initializerDiff.getExpr_dx(), CNE->getAllocatedTypeSourceInfo()); + return {clonedNewE, derivedNewE}; + } -StmtDiff -ReverseModeVisitor::VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE) { - StmtDiff argDiff = Visit(CDE->getArgument()); - Expr* clonedDeleteE = - m_Sema - .ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(), - argDiff.getExpr()) - .get(); - Expr* derivedDeleteE = - m_Sema - .ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(), - argDiff.getExpr_dx()) - .get(); - // create a compound statement containing both the cloned and the derived - // delete expressions. - CompoundStmt* CS = MakeCompoundStmt({clonedDeleteE, derivedDeleteE}); - m_DeallocExprs.push_back(CS); - return {nullptr, nullptr}; -} + StmtDiff + ReverseModeVisitor::VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE) { + StmtDiff argDiff = Visit(CDE->getArgument()); + Expr* clonedDeleteE = + m_Sema + .ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(), + argDiff.getExpr()) + .get(); + Expr* derivedDeleteE = + m_Sema + .ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(), + argDiff.getExpr_dx()) + .get(); + // create a compound statement containing both the cloned and the derived + // delete expressions. + CompoundStmt* CS = MakeCompoundStmt({clonedDeleteE, derivedDeleteE}); + m_DeallocExprs.push_back(CS); + return {nullptr, nullptr}; + } // FIXME: Add support for differentiating calls to constructors. // We currently assume that constructor arguments are non-differentiable. diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index de181231d..25fce39b3 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -110,16 +110,41 @@ double fn5(double i, double j) { // CHECK-NEXT: return *(_d_arr + idx1) + *(_d_arr + idx2); // CHECK-NEXT: } +struct T { + double i; + int j; +}; + +double fn6 (double i) { + T* t = new T{i}; + double res = t->i; + delete t; + return res; +} + +// CHECK: double fn6_darg0(double i) { +// CHECK-NEXT: double _d_i = 1; +// CHECK-NEXT: T *_d_t = new T({_d_i, /*implicit*/(int)0}); +// CHECK-NEXT: T *t = new T({i, /*implicit*/(int)0}); +// CHECK-NEXT: double _d_res = _d_t->i; +// CHECK-NEXT: double res = t->i; +// CHECK-NEXT: delete _d_t; +// CHECK-NEXT: delete t; +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + int main() { INIT_DIFFERENTIATE(fn1, "i"); INIT_DIFFERENTIATE(fn2, "i"); INIT_DIFFERENTIATE(fn3, "i"); INIT_DIFFERENTIATE(fn4, "i"); INIT_DIFFERENTIATE(fn5, "i"); + INIT_DIFFERENTIATE(fn6, "i"); TEST_DIFFERENTIATE(fn1, 3, 5); // CHECK-EXEC: {5.00} TEST_DIFFERENTIATE(fn2, 3, 5); // CHECK-EXEC: {5.00} TEST_DIFFERENTIATE(fn3, 3, 5); // CHECK-EXEC: {6.00} TEST_DIFFERENTIATE(fn4, 3, 5); // CHECK-EXEC: {16.00} TEST_DIFFERENTIATE(fn5, 3, 5); // CHECK-EXEC: {57.00} + TEST_DIFFERENTIATE(fn6, 3); // CHECK-EXEC: {1.00} } diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index 11a66705c..cb2b66ee9 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -395,14 +395,40 @@ double newAndDeletePointer(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: * _d_j += *_d_q; // CHECK-NEXT: * _d_i += *_d_p; -// CHECK-NEXT: delete [] r; -// CHECK-NEXT: delete [] _d_r; -// CHECK-NEXT: delete q; -// CHECK-NEXT: delete _d_q; // CHECK-NEXT: delete p; // CHECK-NEXT: delete _d_p; +// CHECK-NEXT: delete q; +// CHECK-NEXT: delete _d_q; +// CHECK-NEXT: delete [] r; +// CHECK-NEXT: delete [] _d_r; +// CHECK-NEXT: } + +struct T { + double x; + int y; +}; + +double structPointer (double x) { + T* t = new T{x}; + double res = t->x; + delete t; + return res; +} + +// CHECK: void structPointer_grad(double x, clad::array_ref _d_x) { +// CHECK-NEXT: T *_d_t = 0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: _d_t = new T; +// CHECK-NEXT: T *t = new T({x, /*implicit*/(int)0}); +// CHECK-NEXT: double res = t->x; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: _d_t->x += _d_res; +// CHECK-NEXT: * _d_x += *_d_t.x; +// CHECK-NEXT: delete t; +// CHECK-NEXT: delete _d_t; // CHECK-NEXT: } - #define NON_MEM_FN_TEST(var)\ res[0]=0;\ @@ -503,4 +529,8 @@ int main() { double d_i = 0, d_j = 0; d_newAndDeletePointer.execute(5, 7, &d_i, &d_j); printf("%.2f %.2f\n", d_i, d_j); // CHECK-EXEC: 9.00 7.00 + + auto d_structPointer = clad::gradient(structPointer); + double d_x = 0; + d_structPointer.execute(5, &d_x); } \ No newline at end of file