Skip to content

Commit

Permalink
Add bitmasked-option for tbr analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Mar 12, 2024
1 parent d7e5434 commit db242df
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 88 deletions.
23 changes: 18 additions & 5 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,36 @@ enum order {
third = 3,
}; // enum order

enum opts {
enum opts : unsigned {
use_enzyme = 1 << ORDER_BITS,
vector_mode = 1 << (ORDER_BITS + 1),

// Storing two bits for tbr analysis.
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),
}; // enum opts

constexpr unsigned GetDerivativeOrder(unsigned const bitmasked_opts) {
// Define overload for bitwise operations on clad::opts
/// Bitwise OR for clad::opts: both operands are converted to `unsigned`,
/// OR-ed together, and the result is converted back to `opts` so that
/// combined options keep the enum type.
constexpr opts operator|(opts lhs, opts rhs) {
  return static_cast<opts>(static_cast<unsigned>(lhs) |
                           static_cast<unsigned>(rhs));
}
/// Bitwise AND for clad::opts: both operands are converted to `unsigned`,
/// AND-ed together, and the result is converted back to `opts` so that
/// the intersection of two option sets keeps the enum type.
constexpr opts operator&(opts lhs, opts rhs) {
  return static_cast<opts>(static_cast<unsigned>(lhs) &
                           static_cast<unsigned>(rhs));
}

/// Extracts the derivative-order field from a bitmasked options value.
///
/// \param[in] bitmasked_opts combined options value.
/// \returns the bits of \p bitmasked_opts selected by ORDER_MASK.
// NOTE(review): ORDER_MASK is defined outside this view; presumably it
// covers the low ORDER_BITS bits -- confirm against its definition.
constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) {
  return ORDER_MASK & bitmasked_opts;
}

constexpr bool HasOption(unsigned const bitmasked_opts, unsigned const option) {
/// Checks whether every bit of \p option is set in \p bitmasked_opts.
///
/// \param[in] bitmasked_opts combined options value to inspect.
/// \param[in] option bit pattern to look for; may contain several bits.
/// \returns true only when no bit of \p option is missing from
/// \p bitmasked_opts (so an \p option of 0 always yields true).
constexpr bool HasOption(const unsigned bitmasked_opts, const unsigned option) {
  // A bit that is requested but absent survives `option & ~bitmasked_opts`.
  return (option & ~bitmasked_opts) == 0U;
}

/// Base case with no options supplied: the combined bitmask is empty.
constexpr unsigned GetBitmaskedOpts() {
  return 0U;
}
constexpr unsigned GetBitmaskedOpts(unsigned const first) { return first; }
/// Single-option case: the combined bitmask is the option itself.
///
/// \param[in] first the only option value supplied.
/// \returns \p first unchanged.
constexpr unsigned GetBitmaskedOpts(const unsigned first) {
  return first;
}
template <typename... Opts>
constexpr unsigned GetBitmaskedOpts(unsigned const first, Opts... opts) {
constexpr unsigned GetBitmaskedOpts(const unsigned first, Opts... opts) {
return first | GetBitmaskedOpts(opts...);
}

Expand Down
10 changes: 9 additions & 1 deletion include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ namespace clad {
using DiffSchedule = llvm::SmallVector<DiffRequest, 16>;
using DiffInterval = std::vector<clang::SourceRange>;

/// Settings handed to the DiffCollector that provide defaults for the
/// differentiation requests it creates.
struct RequestOptions {
  /// Default choice for running TBR analysis during reverse-mode
  /// differentiation. It is copied into a request when the call specifies
  /// neither the enable nor the disable TBR option. Off unless set.
  bool EnableTBRAnalysis{false};
};

class DiffCollector: public clang::RecursiveASTVisitor<DiffCollector> {
/// The source interval where clad was activated.
///
Expand All @@ -101,9 +107,11 @@ namespace clad {
const clang::FunctionDecl* m_TopMostFD = nullptr;
clang::Sema& m_Sema;

RequestOptions& m_Options;

public:
DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval,
DiffSchedule& plans, clang::Sema& S);
DiffSchedule& plans, clang::Sema& S, RequestOptions& opts);
bool VisitCallExpr(clang::CallExpr* E);

private:
Expand Down
60 changes: 25 additions & 35 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// function to differentiate \param[in] args independent parameter
/// information \returns `CladFunction` object to access the corresponding
/// derived function.
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F,
typename DerivedFnType = ExtractDerivedFnTraitsForwMode_t<F>,
typename = typename std::enable_if<
Expand All @@ -311,7 +311,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F,
typename DerivedFnType = ExtractDerivedFnTraitsForwMode_t<F>,
typename = typename std::enable_if<
Expand All @@ -334,7 +334,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F,
typename DerivedFnType = ExtractDerivedFnTraitsVecForwMode_t<F>,
typename = typename std::enable_if<
Expand All @@ -358,9 +358,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <unsigned... BitMaskedOpts /*To check for enzyme*/,
typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedFnTraits_t<F>,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = GradientDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
Expand All @@ -376,9 +375,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <unsigned... BitMaskedOpts /*To check for enzyme*/,
typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedFnTraits_t<F>,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = GradientDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
Expand All @@ -397,38 +395,34 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = HessianDerivedFnTraits_t<F>,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = HessianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("H")))
hessian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by hessian*/,
code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code);
}

/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = HessianDerivedFnTraits_t<F>,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = HessianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("H")))
hessian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by hessian*/,
code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code, f);
}

/// Generates function which computes jacobian matrix of the given function
Expand All @@ -438,38 +432,34 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("J")))
jacobian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by Jacobian*/,
code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
}

/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
template <clad::opts... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("J")))
jacobian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by Jacobian*/,
code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code, f);
}

template <typename ArgSpec = const char*, typename F,
Expand Down
56 changes: 36 additions & 20 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ namespace clad {
}

DiffCollector::DiffCollector(DeclGroupRef DGR, DiffInterval& Interval,
DiffSchedule& plans, clang::Sema& S)
DiffSchedule& plans, clang::Sema& S,
RequestOptions& opts)
: m_Interval(Interval), m_DiffPlans(plans), m_TopMostFD(nullptr),
m_Sema(S) {
m_Sema(S), m_Options(opts) {

if (Interval.empty())
return;
Expand Down Expand Up @@ -556,27 +557,53 @@ namespace clad {
return true;
DiffRequest request{};

if (A->getAnnotation().equals("D")) {
request.Mode = DiffMode::forward;

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
bool enable_tbr_in_req = false;
bool disable_tbr_in_req = false;
if (!A->getAnnotation().equals("E") &&
FD->getTemplateSpecializationArgs()) {
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();

// Set option for TBR analysis.
enable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr);
disable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr);
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
return true;
}
if (enable_tbr_in_req || disable_tbr_in_req) {
// override the default value of TBR analysis.
request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req;
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
}

if (A->getAnnotation().equals("D")) {
request.Mode = DiffMode::forward;
unsigned derivative_order =
clad::GetDerivativeOrder(bitmasked_opts_value);
if (derivative_order == 0) {
derivative_order = 1; // default to first order derivative.
}
request.RequestedDerivativeOrder = derivative_order;
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) {
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
if (enable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"TBR analysis is not meant for forward mode AD.");
return true;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
request.Mode = DiffMode::vector_forward_mode;
Expand All @@ -601,17 +628,6 @@ namespace clad {
request.Mode = DiffMode::jacobian;
} else if (A->getAnnotation().equals("G")) {
request.Mode = DiffMode::reverse;

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
// reverse vector mode is not yet supported.
Expand Down
4 changes: 2 additions & 2 deletions test/Analyses/TBR.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang -mllvm -debug-only=clad-tbr -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s
// RUN: %cladclang -mllvm -debug-only=clad-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s
// REQUIRES: asserts
//CHECK-NOT: {{.*error|warning|note:.*}}

Expand All @@ -13,7 +13,7 @@ double f1(double x) {

#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient(F);\
auto F##grad = clad::gradient<clad::opts::enable_tbr>(F);\
F##grad.execute(x, result);\
printf("{%.2f}\n", result[0]); \
}
Expand Down
3 changes: 2 additions & 1 deletion test/FirstDerivative/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ int main () {
clad::differentiate(test_6, "x");
clad::differentiate(test_7, "i");
clad::differentiate(test_8, "x");

clad::differentiate<clad::opts::enable_tbr>(test_8); // expected-error {{TBR analysis is not meant for forward mode AD.}}
clad::differentiate<clad::opts::enable_tbr, clad::opts::disable_tbr>(test_8); // expected-error {{Both enable and disable TBR options are specified.}}
return 0;
}
2 changes: 1 addition & 1 deletion test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ int main() {
d_structPointer.execute(5, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 1.00

auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x");
auto d_cStyleMemoryAlloc = clad::gradient<clad::opts::disable_tbr>(cStyleMemoryAlloc, "x");
d_x = 0;
d_cStyleMemoryAlloc.execute(5, 7, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 4.00
Expand Down
7 changes: 7 additions & 0 deletions test/Misc/Args.C
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
// CHECK_HELP-NEXT: -fdump-derived-fn
// CHECK_HELP-NEXT: -fdump-derived-fn-ast
// CHECK_HELP-NEXT: -fgenerate-source-file
// CHECK_HELP-NEXT: -fno-validate-clang-version
// CHECK_HELP-NEXT: -enable-tbr
// CHECK_HELP-NEXT: -disable-tbr
// CHECK_HELP-NEXT: -fcustom-estimation-model
// CHECK_HELP-NEXT: -fprint-num-diff-errors
// CHECK_HELP-NEXT: -help
Expand All @@ -23,3 +26,7 @@
// RUN: -Xclang %t.so %S/../../demos/ErrorEstimation/CustomModel/test.cpp \
// RUN: -I%S/../../include 2>&1 | FileCheck --check-prefix=CHECK_SO_INVALID %s
// CHECK_SO_INVALID: Failed to load '{{.*.so}}', {{.*}}. Aborting.

// RUN: clang -fsyntax-only -fplugin=%cladlib -Xclang -plugin-arg-clad -Xclang -enable-tbr \
// RUN: -Xclang -plugin-arg-clad -Xclang -disable-tbr %s 2>&1 | FileCheck --check-prefix=CHECK_TBR %s
// CHECK_TBR: -enable-tbr and -disable-tbr cannot be used together
Loading

0 comments on commit db242df

Please sign in to comment.