From 7e8c08d7127c59ab2c822442dacecc42d228c351 Mon Sep 17 00:00:00 2001 From: kchristin Date: Fri, 22 Nov 2024 00:06:41 +0200 Subject: [PATCH] Fix return stmt cast to 1 when it's not a scalar --- lib/Differentiator/ReverseModeVisitor.cpp | 11 +++++++---- test/ForwardMode/STLCustomDerivatives.C | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 7a6469333..ab5b11605 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1216,7 +1216,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const Expr* value = RS->getRetValue(); QualType type = value->getType(); auto* dfdf = m_Pullback; - if (dfdf && (isa(dfdf) || isa(dfdf))) { + if (dfdf && (isa(dfdf) || isa(dfdf)) && + type->isScalarType()) { ExprResult tmp = dfdf; dfdf = m_Sema .ImpCastExprToType(tmp.get(), type, @@ -1277,6 +1278,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) { + if (!dfdx()) + return StmtDiff(Clone(ILE)); QualType ILEType = ILE->getType(); llvm::SmallVector clonedExprs(ILE->getNumInits()); if (isArrayOrPointerType(ILEType)) { @@ -1302,12 +1305,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto field_iterator = ILEType->getAsCXXRecordDecl()->field_begin(); std::advance(field_iterator, i); Expr* member_acess = nullptr; - if (dfdx()) - member_acess = utils::BuildMemberExpr( - m_Sema, getCurrentScope(), dfdx(), (*field_iterator)->getName()); + member_acess = utils::BuildMemberExpr(m_Sema, getCurrentScope(), dfdx(), + (*field_iterator)->getName()); clonedExprs[i] = Visit(ILE->getInit(i), member_acess).getExpr(); } Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get(); + printf("before cloning\n"); return StmtDiff(clonedILE); } diff --git a/test/ForwardMode/STLCustomDerivatives.C b/test/ForwardMode/STLCustomDerivatives.C index 7ee45affa..35a321b2c 100644 --- a/test/ForwardMode/STLCustomDerivatives.C +++ b/test/ForwardMode/STLCustomDerivatives.C @@ -426,4 +426,4 @@ int main() { TEST_DIFFERENTIATE(fnArr1, 3); // CHECK-EXEC: {3.00} TEST_DIFFERENTIATE(fnArr2, 3); // CHECK-EXEC: {108.00} TEST_DIFFERENTIATE(fnTuple1, 3, 4); // CHECK-EXEC: {2.00} -} +} \ No newline at end of file