Skip to content

Commit

Permalink
Add support for computing only the diagonal hessian entries (vgvassil…
Browse files Browse the repository at this point in the history
…ev#950)

* Add support for computing only the diagonal hessian entries

fixes vgvassilev#509
  • Loading branch information
vaithak authored Jun 22, 2024
1 parent e37264d commit 1267d57
Show file tree
Hide file tree
Showing 14 changed files with 222 additions and 51 deletions.
1 change: 1 addition & 0 deletions benchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ if (CLAD_ENABLE_ENZYME_BACKEND)
endif(CLAD_ENABLE_ENZYME_BACKEND)
CB_ADD_GBENCHMARK(VectorModeComparison VectorModeComparison.cpp)
CB_ADD_GBENCHMARK(MemoryComplexity MemoryComplexity.cpp)
CB_ADD_GBENCHMARK(Hessians Hessians.cpp)

set (CLAD_BENCHMARK_DEPS clad)
get_property(_benchmark_names DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY TESTS)
Expand Down
47 changes: 47 additions & 0 deletions benchmark/Hessians.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "benchmark/benchmark.h"

#include "clad/Differentiator/Differentiator.h"

#include "BenchmarkedFunctions.h"

// Benchmark the sum of the Hessian diagonal when the diagonal is obtained by
// generating and executing the complete Hessian and then reading only the
// diagonal entries out of the resulting matrix.
static void BM_HessianCompleteComputation(benchmark::State& state) {
  auto dfdx2 = clad::hessian(weightedSum, "p[0:1],w[0:1]");
  double p[] = {1, 2};
  double w[] = {3, 4};
  // "p[0:1],w[0:1]" requests two elements of p and two of w, i.e. four
  // independent variables and a 4x4 (row-major) Hessian.
  constexpr unsigned kNumVars = 4;
  // Accumulate in double: the Hessian entries are doubles, and folding them
  // into an unsigned integer would be undefined for negative values.
  double sum = 0;
  // `= {}` already zero-initializes every element of the output buffer.
  double hessianMatrix[kNumVars * kNumVars] = {};
  for (auto _ : state) {
    dfdx2.execute(p, w, 3, hessianMatrix);
    for (unsigned i = 0; i < kNumVars; ++i)
      // Sum the diagonal of the Hessian matrix; DoNotOptimize keeps the
      // compiler from discarding the otherwise unused accumulation.
      benchmark::DoNotOptimize(sum += hessianMatrix[i * kNumVars + i]);
  }
}
BENCHMARK(BM_HessianCompleteComputation);

// Benchmark the sum of the Hessian diagonal when only the diagonal entries
// are computed, by requesting clad::opts::diagonal_only.
static void BM_HessianDiagonalComputation(benchmark::State& state) {
  auto dfdx2 =
      clad::hessian<clad::opts::diagonal_only>(weightedSum, "p[0:1],w[0:1]");
  double p[] = {1, 2};
  double w[] = {3, 4};
  // "p[0:1],w[0:1]" requests two elements of p and two of w, i.e. four
  // independent variables and therefore four diagonal entries.
  constexpr unsigned kNumVars = 4;
  // Accumulate in double: the Hessian entries are doubles, and folding them
  // into an unsigned integer would be undefined for negative values.
  double sum = 0;
  // `= {}` already zero-initializes every element of the output buffer.
  double diagonalHessian[kNumVars] = {};
  for (auto _ : state) {
    dfdx2.execute(p, w, 3, diagonalHessian);
    for (unsigned i = 0; i < kNumVars; ++i)
      // Sum the diagonal entries; DoNotOptimize keeps the compiler from
      // discarding the otherwise unused accumulation.
      benchmark::DoNotOptimize(sum += diagonalHessian[i]);
  }
}
BENCHMARK(BM_HessianDiagonalComputation);

// Define our main.
BENCHMARK_MAIN();
3 changes: 3 additions & 0 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ enum opts : unsigned {
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),

// Specifying whether we only want the diagonal of the hessian.
diagonal_only = 1 << (ORDER_BITS + 4),
}; // enum opts

constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) {
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enum class DiffMode {
experimental_vector_pushforward,
reverse,
hessian,
hessian_diagonal,
jacobian,
reverse_mode_forward_pass,
error_estimation
Expand All @@ -33,6 +34,8 @@ inline const char* DiffModeToString(DiffMode mode) {
return "reverse";
case DiffMode::hessian:
return "hessian";
case DiffMode::hessian_diagonal:
return "hessian_diagonal";
case DiffMode::jacobian:
return "jacobian";
case DiffMode::reverse_mode_forward_pass:
Expand Down
8 changes: 5 additions & 3 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ struct DiffRequest {
// A flag to enable the use of enzyme for backend instead of clad
bool use_enzyme = false;

/// A pointer to keep track of the prototype of the derived function.
/// This will be particularly useful for pushforward and pullback functions.
clang::FunctionDecl* DerivedFDPrototype = nullptr;
/// A pointer to keep track of the prototype of the derived functions.
/// For higher order derivatives, we store the entire sequence of
/// prototypes declared for all orders of derivatives.
/// This will be useful for forward declaration of the derived functions.
llvm::SmallVector<clang::FunctionDecl*, 2> DerivedFDPrototypes;

/// A boolean to indicate if only the declaration of the derived function
/// is required (and not the definition or body).
Expand Down
14 changes: 10 additions & 4 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,11 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down Expand Up @@ -529,8 +532,11 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,

endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}

m_Sema.PopFunctionScopeInfo();
Expand Down
20 changes: 15 additions & 5 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include "clad/Differentiator/VectorForwardModeVisitor.h"
#include "clad/Differentiator/VectorPushForwardModeVisitor.h"

#include "llvm/Support/SaveAndRestore.h"

#include <algorithm>

#include "clad/Differentiator/CladUtils.h"
Expand Down Expand Up @@ -295,9 +297,17 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
FunctionDecl* derivative = this->FindDerivedFunction(request);
if (!derivative) {
alreadyDerived = false;
// Derive declaration of the forward mode derivative.
request.DeclarationOnly = true;
derivative = plugin::ProcessDiffRequest(m_CladPlugin, request);

{
// Store and restore the original function and its order.
llvm::SaveAndRestore<const FunctionDecl*> origFn(request.Function);
llvm::SaveAndRestore<unsigned> origFnOrder(
request.CurrentDerivativeOrder);

// Derive declaration of the forward mode derivative.
request.DeclarationOnly = true;
derivative = plugin::ProcessDiffRequest(m_CladPlugin, request);
}

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
Expand All @@ -309,7 +319,6 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
// Add the request to derive the definition of the forward mode derivative
// to the schedule.
request.DeclarationOnly = false;
request.DerivedFDPrototype = derivative;
}
this->AddEdgeToGraph(request, alreadyDerived);
return derivative;
Expand Down Expand Up @@ -423,7 +432,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this, request);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::hessian) {
} else if (request.Mode == DiffMode::hessian ||
request.Mode == DiffMode::hessian_diagonal) {
HessianModeVisitor H(*this, request);
result = H.Derive(FD, request);
} else if (request.Mode == DiffMode::jacobian) {
Expand Down
13 changes: 12 additions & 1 deletion lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,14 @@ namespace clad {
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (!A->getAnnotation().equals("H")) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Diagonal only option is only valid for Hessian "
"mode.");
return true;
}
}
}

if (A->getAnnotation().equals("D")) {
Expand Down Expand Up @@ -651,7 +659,10 @@ namespace clad {
}
}
} else if (A->getAnnotation().equals("H")) {
request.Mode = DiffMode::hessian;
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only))
request.Mode = DiffMode::hessian_diagonal;
else
request.Mode = DiffMode::hessian;
} else if (A->getAnnotation().equals("J")) {
request.Mode = DiffMode::jacobian;
} else if (A->getAnnotation().equals("G")) {
Expand Down
109 changes: 78 additions & 31 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,24 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
return secondDerivative;
}

/// Differentiates the requested function twice using forward mode AD and
/// returns the FunctionDecl produced by the nested differentiation request.
///
/// \param SemaRef Sema instance passed to UpdateDiffParamsInfo.
/// \param CP not referenced in this body.
/// \param Builder builder that handles the nested differentiation request.
/// \param IndependentArgRequest request to specialise; taken by value, so the
///        modifications made here are not visible to the caller.
/// \param ForwardModeArgs expression installed as the request's Args.
/// \param DFC not referenced in this body.
static FunctionDecl* DeriveUsingForwardModeTwice(
    Sema& SemaRef, clad::plugin::CladPlugin& CP,
    clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest,
    const Expr* ForwardModeArgs, DerivedFnCollector& DFC) {
  // Turn the local copy of the request into a second-order forward mode
  // request over the given independent argument.
  IndependentArgRequest.Mode = DiffMode::forward;
  IndependentArgRequest.Args = ForwardModeArgs;
  IndependentArgRequest.RequestedDerivativeOrder = 2;
  IndependentArgRequest.CallUpdateRequired = false;
  // Refresh the parameter info only after every field above has been set.
  IndependentArgRequest.UpdateDiffParamsInfo(SemaRef);
  // The builder performs the (twice) forward mode differentiation.
  return Builder.HandleNestedDiffRequest(IndependentArgRequest);
}

DerivativeAndOverload
HessianModeVisitor::Derive(const clang::FunctionDecl* FD,
const DiffRequest& request) {
Expand All @@ -91,14 +109,16 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

std::vector<FunctionDecl*> secondDerivativeColumns;
std::vector<FunctionDecl*> secondDerivativeFuncs;
llvm::SmallVector<size_t, 16> IndependentArgsSize{};
size_t TotalIndependentArgsSize = 0;

// request.Function is original function passed in from clad::hessian
assert(m_DiffReq == request);

std::string hessianFuncName = request.BaseFunctionName + "_hessian";
if (request.Mode == DiffMode::hessian_diagonal)
hessianFuncName += "_diagonal";
// To be consistent with older tests, nothing is appended to 'f_hessian' if
// we differentiate w.r.t. all the parameters at once.
if (args.size() != FD->getNumParams() ||
Expand Down Expand Up @@ -192,27 +212,38 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
PVD->getNameAsString() + "[" + std::to_string(i) + "]";
auto ForwardModeIASL =
CreateStringLiteral(m_Context, independentArgString);
auto* DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeColumns.push_back(DFD);
FunctionDecl* DFD = nullptr;
if (request.Mode == DiffMode::hessian_diagonal)
DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder,
request, ForwardModeIASL,
m_Builder.m_DFC);
else
DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeFuncs.push_back(DFD);
}

} else {
IndependentArgsSize.push_back(1);
TotalIndependentArgsSize++;
// Derive the function w.r.t. the current arg in forward mode and
// then in reverse mode w.r.t. all requested args
auto ForwardModeIASL =
CreateStringLiteral(m_Context, PVD->getNameAsString());
auto* DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeColumns.push_back(DFD);
FunctionDecl* DFD = nullptr;
if (request.Mode == DiffMode::hessian_diagonal)
DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder,
request, ForwardModeIASL,
m_Builder.m_DFC);
else
DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeFuncs.push_back(DFD);
}
}
}
return Merge(secondDerivativeColumns, IndependentArgsSize,
return Merge(secondDerivativeFuncs, IndependentArgsSize,
TotalIndependentArgsSize, hessianFuncName, DC,
hessianFunctionType, paramTypes);
}
Expand Down Expand Up @@ -272,14 +303,13 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
return VD;
});

// The output parameter "hessianMatrix".
// The output parameter "hessianMatrix" or "diagonalHessianVector"
std::string outputParamName = "hessianMatrix";
if (m_DiffReq.Mode == DiffMode::hessian_diagonal)
outputParamName = "diagonalHessianVector";
params.back() = ParmVarDecl::Create(
m_Context,
hessianFD,
noLoc,
noLoc,
&m_Context.Idents.get("hessianMatrix"),
paramTypes.back(),
m_Context, hessianFD, noLoc, noLoc,
&m_Context.Idents.get(outputParamName), paramTypes.back(),
m_Context.getTrivialTypeSourceInfo(paramTypes.back(), noLoc),
params.front()->getStorageClass(),
/* No default value */ nullptr);
Expand All @@ -301,7 +331,6 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
// Creates callExprs to the second derivative functions generated
// and maps array elements to input array.
for (size_t i = 0, e = secDerivFuncs.size(); i < e; ++i) {
const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize;
auto size_type = m_Context.getSizeType();
auto size_type_bits = m_Context.getIntWidth(size_type);

Expand Down Expand Up @@ -345,22 +374,40 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
}
}

size_t columnIndex = 0;
// Create Expr parameters for each independent arg in the CallExpr
for (size_t indArgSize : IndependentArgsSize) {
llvm::APInt offsetValue(size_type_bits,
HessianMatrixStartIndex + columnIndex);
if (m_DiffReq.Mode == DiffMode::hessian_diagonal) {
const size_t HessianMatrixStartIndex = i;
// Call the derived function for second derivative.
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);

// Create the offset argument.
llvm::APInt offsetValue(size_type_bits, HessianMatrixStartIndex);
Expr* OffsetArg =
IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc);
// Create the hessianMatrix + OffsetArg expression.
Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg);

DeclRefToParams.push_back(SliceExpr);
columnIndex += indArgSize;
// Create an assignment expression to store the value of call expression
// into the diagonalHessianVector with index HessianMatrixStartIndex.
Expr* SliceExprLHS = BuildOp(BO_Add, m_Result, OffsetArg);
Expr* DerefExpr = BuildOp(UO_Deref, BuildParens(SliceExprLHS));
Expr* AssignExpr = BuildOp(BO_Assign, DerefExpr, call);
CompStmtSave.push_back(AssignExpr);
} else {
const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize;
size_t columnIndex = 0;
// Create Expr parameters for each independent arg in the CallExpr
for (size_t indArgSize : IndependentArgsSize) {
llvm::APInt offsetValue(size_type_bits,
HessianMatrixStartIndex + columnIndex);
// Create the offset argument.
Expr* OffsetArg =
IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc);
// Create the hessianMatrix + OffsetArg expression.
Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg);

DeclRefToParams.push_back(SliceExpr);
columnIndex += indArgSize;
}
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);
CompStmtSave.push_back(call);
}
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);
CompStmtSave.push_back(call);
}

auto StmtsRef =
Expand Down
7 changes: 5 additions & 2 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,11 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
m_Derivative->setBody(fnBody);
endScope();

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down
Loading

0 comments on commit 1267d57

Please sign in to comment.