diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index fda66a862..bb8355488 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -663,8 +663,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Prepare Arguments and Parameters to enzyme_autodiff llvm::SmallVector enzymeArgs; llvm::SmallVector enzymeParams; - llvm::SmallVector enzymeRealParams; - llvm::SmallVector enzymeRealParamsDerived; + llvm::SmallVector enzymeRealParamsDerived; // First add the function itself as a parameter/argument enzymeArgs.push_back(BuildDeclRef(const_cast(m_Function))); @@ -675,23 +674,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Add rest of the parameters/arguments for (unsigned i = 0; i < numParams; i++) { + ParmVarDecl* param = paramsRef[i]; // First Add the original parameter - enzymeArgs.push_back(BuildDeclRef(paramsRef[i])); + enzymeArgs.push_back(BuildDeclRef(param)); enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, paramsRef[i]->getType())); + fdDeclContext, noLoc, param->getType())); QualType paramType = origParams[i]->getOriginalType(); // If original parameter is of a differentiable real type(but not // array/pointer), then add it to the list of params whose gradient must // be extracted later from the EnzymeGradient structure if (paramType->isRealFloatingType()) { - enzymeRealParams.push_back(paramsRef[i]); - enzymeRealParamsDerived.push_back(paramsRef[numParams + i]); + enzymeRealParamsDerived.push_back(m_Variables[param]); } else if (utils::isArrayOrPointerType(paramType)) { // Add the corresponding array/pointer variable - enzymeArgs.push_back(BuildDeclRef(paramsRef[numParams + i])); + enzymeArgs.push_back(m_Variables[param]); enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, paramsRef[numParams + i]->getType())); + fdDeclContext, noLoc, m_Variables[param]->getType())); } } @@ -700,12 +699,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, enzymeParamsType.push_back(i->getType()); QualType QT; - if (!enzymeRealParams.empty()) { + if (!enzymeRealParamsDerived.empty()) { // Find the EnzymeGradient datastructure auto* gradDecl = LookupTemplateDeclInCladNamespace("EnzymeGradient"); TemplateArgumentListInfo TLI{}; - llvm::APSInt argValue(std::to_string(enzymeRealParams.size())); + llvm::APSInt argValue(std::to_string(enzymeRealParamsDerived.size())); TemplateArgument TA(m_Context, argValue, m_Context.UnsignedIntTy); TLI.addArgument(TemplateArgumentLoc(TA, TemplateArgumentLocInfo())); @@ -730,13 +729,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Prepare the statements that assign the gradients to // non array/pointer type parameters of the original function - if (!enzymeRealParams.empty()) { + if (!enzymeRealParamsDerived.empty()) { auto* gradDeclStmt = BuildVarDecl(QT, "grad", enzymeCall, true); addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward); - for (unsigned i = 0; i < enzymeRealParams.size(); i++) { - auto* LHSExpr = - BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsDerived[i])); + for (unsigned i = 0; i < enzymeRealParamsDerived.size(); i++) { + auto* LHSExpr = Clone(enzymeRealParamsDerived[i]); auto* ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(), BuildDeclRef(gradDeclStmt), "d_arr"); @@ -753,7 +751,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } else { // Add Function call to block - Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); addToCurrentBlock(enzymeCall); } } diff --git a/test/Enzyme/FunctionPrototypesReverseMode.C b/test/Enzyme/FunctionPrototypesReverseMode.C index 0407b9fd5..ff2df3358 100644 --- a/test/Enzyme/FunctionPrototypesReverseMode.C +++ b/test/Enzyme/FunctionPrototypesReverseMode.C @@ -30,7 +30,7 @@ double f3(double* arr, int n){ return sum; } -// CHECK: void f3_grad_enzyme(double *arr, int n, double *_d_arr, int *_d_n) { +// CHECK: void f3_grad_enzyme(double *arr, int n, double *_d_arr) { // CHECK-NEXT: __enzyme_autodiff_f3(f3, arr, _d_arr, n); // CHECK-NEXT: } @@ -45,7 +45,7 @@ double f4(double* arr1, int n, double* arr2, int m){ return sum; } -// CHECK: void f4_grad_enzyme(double *arr1, int n, double *arr2, int m, double *_d_arr1, int *_d_n, double *_d_arr2, int *_d_m) { +// CHECK: void f4_grad_enzyme(double *arr1, int n, double *arr2, int m, double *_d_arr1, double *_d_arr2) { // CHECK-NEXT: __enzyme_autodiff_f4(f4, arr1, _d_arr1, n, arr2, _d_arr2, m); // CHECK-NEXT: } @@ -57,7 +57,7 @@ double f5(double arr[], double x,int n,double y){ return res; } -// CHECK: void f5_grad_enzyme(double arr[], double x, int n, double y, double *_d_arr, double *_d_x, int *_d_n, double *_d_y) { +// CHECK: void f5_grad_enzyme(double arr[], double x, int n, double y, double *_d_arr, double *_d_x, double *_d_y) { // CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_f5(f5, arr, _d_arr, x, n, y); // CHECK-NEXT: *_d_x = grad.d_arr[0U]; // CHECK-NEXT: *_d_y = grad.d_arr[1U]; @@ -82,30 +82,27 @@ int main() { auto f3_grad=clad::gradient(f3); double f3_list[3]={3,4,5}; double f3_res[3]={0}; - int f3_dn=0; - f3_grad.execute(f3_list,3,f3_res,&f3_dn); - printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n = %d\n",f3_res[0],f3_res[1],f3_res[2],f3_dn); - //CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n = 0 + f3_grad.execute(f3_list,3,f3_res); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f\n",f3_res[0],f3_res[1],f3_res[2]); + //CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00 auto f4_grad=clad::gradient(f4); double f4_list1[3]={3,4,5}; double f4_list2[2]={1,2}; double f4_res1[3]={0}; double f4_res2[2]={0}; - int f4_dn1=0,f4_dn2=0; - f4_grad.execute(f4_list1,3,f4_list2,2,f4_res1,&f4_dn1,f4_res2,&f4_dn2); - printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d\n",f4_res1[0],f4_res1[1],f4_res1[2],f4_dn1); - //CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n1 = 0 - printf("d_y1 = %.2f, d_y2 = %.2f, d_n2 = %d\n",f4_res2[0],f4_res2[1],f4_dn2); - //CHECK-EXEC: d_y1 = 2.00, d_y2 = 4.00, d_n2 = 0 + f4_grad.execute(f4_list1,3,f4_list2,2,f4_res1,f4_res2); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f\n",f4_res1[0],f4_res1[1],f4_res1[2]); + //CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00 + printf("d_y1 = %.2f, d_y2 = %.2f\n",f4_res2[0],f4_res2[1]); + //CHECK-EXEC: d_y1 = 2.00, d_y2 = 4.00 auto f5_grad=clad::gradient(f5); double f5_list[3]={3,4,5}; double f5_res[3]={0}; double f5_x=10.0,f5_dx=0,f5_y=5,f5_dy; - int f5_dn=0; - f5_grad.execute(f5_list,f5_x,3,f5_y,f5_res,&f5_dx,&f5_dn,&f5_dy); - printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d, d_x = %.2f, d_y = %.2f\n",f5_res[0],f5_res[1],f5_res[2],f5_dn, f5_dx, f5_dy); - //CHECK-EXEC: d_x1 = 50.00, d_x2 = 50.00, d_x3 = 50.00, d_n1 = 0, d_x = 60.00, d_y = 120.00 + f5_grad.execute(f5_list,f5_x,3,f5_y,f5_res,&f5_dx,&f5_dy); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_x = %.2f, d_y = %.2f\n",f5_res[0],f5_res[1],f5_res[2],f5_dx, f5_dy); + //CHECK-EXEC: d_x1 = 50.00, d_x2 = 50.00, d_x3 = 50.00, d_x = 60.00, d_y = 120.00 }