From 693b1bde992edd81ec89b708eb96c6af0793dd69 Mon Sep 17 00:00:00 2001 From: Christina Koutsou <74819775+kchristin22@users.noreply.github.com> Date: Sat, 23 Nov 2024 08:20:01 +0200 Subject: [PATCH] Clone base decl when having an anonymous struct or union (#1152) Fixes #1151 --- lib/Differentiator/ReverseModeVisitor.cpp | 13 ++-- test/Gradient/UserDefinedTypes.C | 81 +++++++++++++++++++++++ 2 files changed, 90 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 5909ca6e4..9ecf043c8 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -3314,8 +3314,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto* field = ME->getMemberDecl(); assert(!isa(field) && "CXXMethodDecl nodes not supported yet!"); - MemberExpr* clonedME = utils::BuildMemberExpr( - m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); + Expr* clonedME = baseDiff.getExpr(); + llvm::StringRef fieldName = field->getName(); + if (baseDiff.getExpr() && !fieldName.empty()) + clonedME = utils::BuildMemberExpr(m_Sema, getCurrentScope(), + baseDiff.getExpr(), fieldName); if (clad::utils::hasNonDifferentiableAttribute(ME)) { auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, /*val=*/0); @@ -3323,8 +3326,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (!baseDiff.getExpr_dx()) return {clonedME, nullptr}; - MemberExpr* derivedME = utils::BuildMemberExpr( - m_Sema, getCurrentScope(), baseDiff.getExpr_dx(), field->getName()); + Expr* derivedME = baseDiff.getExpr_dx(); + if (!fieldName.empty()) + derivedME = utils::BuildMemberExpr(m_Sema, getCurrentScope(), + baseDiff.getExpr_dx(), fieldName); if (dfdx()) { Expr* addAssign = BuildOp(BinaryOperatorKind::BO_AddAssign, derivedME, dfdx()); diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index 6bc61395a..344f185cb 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -404,6 +404,69 @@ MyStruct fn12(MyStruct s) { // CHECK-NEXT: } // CHECK-NEXT:} +typedef int Fint; +typedef union Findex +{ + struct + { + Fint j, k, l; + }; + Fint dim[3]; +} Findex; + +void fn13(double *x, double *y, int size) +{ + Findex p; + + for (p.j = 0; p.j < size; p.j += 1) + { + y[p.j] = 2.0 * x[p.j]; + } +} + +// CHECK: void fn13_grad_0_1(double *x, double *y, int size, double *_d_x, double *_d_y) { +// CHECK-NEXT: int _d_size = 0; +// CHECK-NEXT: Fint _t1; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: Findex _d_p({}); +// CHECK-NEXT: Findex p; +// CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}}; +// CHECK-NEXT: _t1 = p.j; +// CHECK-NEXT: for (p.j = 0; ; clad::push(_t2, p.j) , (p.j += 1)) { +// CHECK-NEXT: { +// CHECK-NEXT: if (!(p.j < size)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: _t0++; +// CHECK-NEXT: clad::push(_t3, y[p.j]); +// CHECK-NEXT: y[p.j] = 2. * x[p.j]; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: { +// CHECK-NEXT: if (!_t0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p.j = clad::pop(_t2); +// CHECK-NEXT: Fint _r_d1 = _d_p.j; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: y[p.j] = clad::pop(_t3); +// CHECK-NEXT: double _r_d2 = _d_y[p.j]; +// CHECK-NEXT: _d_y[p.j] = 0.; +// CHECK-NEXT: _d_x[p.j] += 2. * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p.j = _t1; +// CHECK-NEXT: Fint _r_d0 = _d_p.j; +// CHECK-NEXT: _d_p.j = 0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT:} + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -416,6 +479,16 @@ void print(const MyStruct& s) { printf("{%.2f, %.2f}\n", s.a, s.b); } +void printArray(double* arr, int size) { + printf("{"); + for (int i = 0; i < size; ++i) { + printf("%.2f", arr[i]); + if (i != size - 1) + printf(", "); + } + printf("}\n"); +} + int main() { pairdd p(3, 5), d_p; double i = 3, d_i, d_j; @@ -454,6 +527,14 @@ int main() { auto fn12_test = clad::gradient(fn12); fn12_test.execute(s, &d_s); print(d_s); // CHECK-EXEC: {2.00, 2.00} + + auto fn13_test = clad::gradient(fn13, "x, y"); + double x[3] = {1.0, 2.0, 3.0}, y[3] = {0.0, 0.0, 0.0}; + double d_x[3] = {0.0, 0.0, 0.0}, d_y[3] = {1.0, 1.0, 1.0}; + int size = 3; + fn13_test.execute(x, y, 3, d_x, d_y); + printArray(d_x, size); // CHECK-EXEC: {2.00, 2.00, 2.00} + printArray(d_y, size); // CHECK-EXEC: {0.00, 0.00, 0.00} } // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {