Skip to content

Commit

Permalink
Remove integral type params from differentiation with Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Apr 30, 2024
1 parent 685ad18 commit 63d07c5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 32 deletions.
27 changes: 12 additions & 15 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,8 +663,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Prepare Arguments and Parameters to enzyme_autodiff
llvm::SmallVector<Expr*, 16> enzymeArgs;
llvm::SmallVector<ParmVarDecl*, 16> enzymeParams;
llvm::SmallVector<ParmVarDecl*, 16> enzymeRealParams;
llvm::SmallVector<ParmVarDecl*, 16> enzymeRealParamsDerived;
llvm::SmallVector<Expr*, 16> enzymeRealParamsDerived;

// First add the function itself as a parameter/argument
enzymeArgs.push_back(BuildDeclRef(const_cast<FunctionDecl*>(m_Function)));
Expand All @@ -675,23 +674,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Add rest of the parameters/arguments
for (unsigned i = 0; i < numParams; i++) {
ParmVarDecl* param = paramsRef[i];
// First Add the original parameter
enzymeArgs.push_back(BuildDeclRef(paramsRef[i]));
enzymeArgs.push_back(BuildDeclRef(param));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramsRef[i]->getType()));
fdDeclContext, noLoc, param->getType()));

QualType paramType = origParams[i]->getOriginalType();
// If original parameter is of a differentiable real type(but not
// array/pointer), then add it to the list of params whose gradient must
// be extracted later from the EnzymeGradient structure
if (paramType->isRealFloatingType()) {
enzymeRealParams.push_back(paramsRef[i]);
enzymeRealParamsDerived.push_back(paramsRef[numParams + i]);
enzymeRealParamsDerived.push_back(m_Variables[param]);
} else if (utils::isArrayOrPointerType(paramType)) {
// Add the corresponding array/pointer variable
enzymeArgs.push_back(BuildDeclRef(paramsRef[numParams + i]));
enzymeArgs.push_back(m_Variables[param]);
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramsRef[numParams + i]->getType()));
fdDeclContext, noLoc, m_Variables[param]->getType()));
}
}

Expand All @@ -700,12 +699,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
enzymeParamsType.push_back(i->getType());

QualType QT;
if (!enzymeRealParams.empty()) {
if (!enzymeRealParamsDerived.empty()) {
// Find the EnzymeGradient datastructure
auto* gradDecl = LookupTemplateDeclInCladNamespace("EnzymeGradient");

TemplateArgumentListInfo TLI{};
llvm::APSInt argValue(std::to_string(enzymeRealParams.size()));
llvm::APSInt argValue(std::to_string(enzymeRealParamsDerived.size()));
TemplateArgument TA(m_Context, argValue, m_Context.UnsignedIntTy);
TLI.addArgument(TemplateArgumentLoc(TA, TemplateArgumentLocInfo()));

Expand All @@ -730,13 +729,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Prepare the statements that assign the gradients to
// non array/pointer type parameters of the original function
if (!enzymeRealParams.empty()) {
if (!enzymeRealParamsDerived.empty()) {
auto* gradDeclStmt = BuildVarDecl(QT, "grad", enzymeCall, true);
addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward);

for (unsigned i = 0; i < enzymeRealParams.size(); i++) {
auto* LHSExpr =
BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsDerived[i]));
for (unsigned i = 0; i < enzymeRealParamsDerived.size(); i++) {
auto* LHSExpr = Clone(enzymeRealParamsDerived[i]);

auto* ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(),
BuildDeclRef(gradDeclStmt), "d_arr");
Expand All @@ -753,7 +751,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
} else {
// Add Function call to block
Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs);
addToCurrentBlock(enzymeCall);
}
}
Expand Down
31 changes: 14 additions & 17 deletions test/Enzyme/FunctionPrototypesReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ double f3(double* arr, int n){
return sum;
}

// CHECK: void f3_grad_enzyme(double *arr, int n, double *_d_arr, int *_d_n) {
// CHECK: void f3_grad_enzyme_0(double *arr, int n, double *_d_arr) {
// CHECK-NEXT: __enzyme_autodiff_f3(f3, arr, _d_arr, n);
// CHECK-NEXT: }

Expand All @@ -45,7 +45,7 @@ double f4(double* arr1, int n, double* arr2, int m){
return sum;
}

// CHECK: void f4_grad_enzyme(double *arr1, int n, double *arr2, int m, double *_d_arr1, int *_d_n, double *_d_arr2, int *_d_m) {
// CHECK: void f4_grad_enzyme_0_2(double *arr1, int n, double *arr2, int m, double *_d_arr1, double *_d_arr2) {
// CHECK-NEXT: __enzyme_autodiff_f4(f4, arr1, _d_arr1, n, arr2, _d_arr2, m);
// CHECK-NEXT: }

Expand All @@ -57,7 +57,7 @@ double f5(double arr[], double x,int n,double y){
return res;
}

// CHECK: void f5_grad_enzyme(double arr[], double x, int n, double y, double *_d_arr, double *_d_x, int *_d_n, double *_d_y) {
// CHECK: void f5_grad_enzyme_0_1_3(double arr[], double x, int n, double y, double *_d_arr, double *_d_x, double *_d_y) {
// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_f5(f5, arr, _d_arr, x, n, y);
// CHECK-NEXT: *_d_x = grad.d_arr[0U];
// CHECK-NEXT: *_d_y = grad.d_arr[1U];
Expand All @@ -82,30 +82,27 @@ int main() {
auto f3_grad=clad::gradient<clad::opts::use_enzyme>(f3);
double f3_list[3]={3,4,5};
double f3_res[3]={0};
int f3_dn=0;
f3_grad.execute(f3_list,3,f3_res,&f3_dn);
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n = %d\n",f3_res[0],f3_res[1],f3_res[2],f3_dn);
//CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n = 0
f3_grad.execute(f3_list,3,f3_res);
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f\n",f3_res[0],f3_res[1],f3_res[2]);
//CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00

auto f4_grad=clad::gradient<clad::opts::use_enzyme>(f4);
double f4_list1[3]={3,4,5};
double f4_list2[2]={1,2};
double f4_res1[3]={0};
double f4_res2[2]={0};
int f4_dn1=0,f4_dn2=0;
f4_grad.execute(f4_list1,3,f4_list2,2,f4_res1,&f4_dn1,f4_res2,&f4_dn2);
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d\n",f4_res1[0],f4_res1[1],f4_res1[2],f4_dn1);
//CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n1 = 0
printf("d_y1 = %.2f, d_y2 = %.2f, d_n2 = %d\n",f4_res2[0],f4_res2[1],f4_dn2);
//CHECK-EXEC: d_y1 = 2.00, d_y2 = 4.00, d_n2 = 0
f4_grad.execute(f4_list1,3,f4_list2,2,f4_res1,f4_res2);
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f\n",f4_res1[0],f4_res1[1],f4_res1[2]);
//CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00
printf("d_y1 = %.2f, d_y2 = %.2f\n",f4_res2[0],f4_res2[1]);
//CHECK-EXEC: d_y1 = 2.00, d_y2 = 4.00

auto f5_grad=clad::gradient<clad::opts::use_enzyme>(f5);
double f5_list[3]={3,4,5};
double f5_res[3]={0};
double f5_x=10.0,f5_dx=0,f5_y=5,f5_dy;
int f5_dn=0;
f5_grad.execute(f5_list,f5_x,3,f5_y,f5_res,&f5_dx,&f5_dn,&f5_dy);
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d, d_x = %.2f, d_y = %.2f\n",f5_res[0],f5_res[1],f5_res[2],f5_dn, f5_dx, f5_dy);
//CHECK-EXEC: d_x1 = 50.00, d_x2 = 50.00, d_x3 = 50.00, d_n1 = 0, d_x = 60.00, d_y = 120.00
f5_grad.execute(f5_list,f5_x,3,f5_y,f5_res,&f5_dx,&f5_dy);
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_x = %.2f, d_y = %.2f\n",f5_res[0],f5_res[1],f5_res[2],f5_dx, f5_dy);
//CHECK-EXEC: d_x1 = 50.00, d_x2 = 50.00, d_x3 = 50.00, d_x = 60.00, d_y = 120.00

}

0 comments on commit 63d07c5

Please sign in to comment.