Skip to content

Commit

Permalink
Delay the differentiation process until the end of TU.
Browse files Browse the repository at this point in the history
Before this patch clad attaches itself as a first consumer and applies AD before
code generation. However, that is limited since clang sends every top-level
declaration to codegen which limits the amount of flexibility clad has. For
example, we have to instantiate all pending templates at every
HandleTopLevelDecl calls; we cannot really differentiate virtual functions
whose classes have sent their key function to CodeGen; and in general we perform
actions which are semantically useful for the end of the translation unit.

This patch makes clad a single consumer of clang which dispatches to the others.
That's done by delaying all calls to the consumers until the end of the TU where
clad can replay the exact sequence of calls to the other consumers as if they
were directly connected to the frontend.

Fixes #248
  • Loading branch information
vgvassilev committed Mar 14, 2024
1 parent 8a77f81 commit d18b840
Show file tree
Hide file tree
Showing 9 changed files with 520 additions and 117 deletions.
30 changes: 15 additions & 15 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,9 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
gradient(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
}

/// Specialization for differentiating functors.
Expand All @@ -384,8 +384,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
gradient(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, f);
}

/// Generates function which computes hessian matrix of the given function wrt
Expand All @@ -404,9 +404,9 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
hessian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code);
}

/// Specialization for differentiating functors.
Expand All @@ -421,8 +421,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
hessian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code, f);
}

/// Generates function which computes jacobian matrix of the given function
Expand All @@ -441,9 +441,9 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
jacobian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
}

/// Specialization for differentiating functors.
Expand All @@ -458,8 +458,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
jacobian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code, f);
}

template <typename ArgSpec = const char*, typename F,
Expand Down
29 changes: 29 additions & 0 deletions include/clad/Differentiator/Sins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef CLAD_DIFFERENTIATOR_SINS_H
#define CLAD_DIFFERENTIATOR_SINS_H

#include <type_traits>

/// Standard-protected facility allowing access into private members in C++.
/// Use with caution!
// NOLINTBEGIN(cppcoreguidelines-macro-usage)
#define CONCATE_(X, Y) X##Y
#define CONCATE(X, Y) CONCATE_(X, Y)
#define ALLOW_ACCESS(CLASS, MEMBER, ...) \
template <typename Only, __VA_ARGS__ CLASS::*Member> \
struct CONCATE(MEMBER, __LINE__) { \
friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \
}; \
template <typename> struct Only_##MEMBER; \
template <> struct Only_##MEMBER<CLASS> { \
friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER<CLASS>*); \
}; \
template struct CONCATE(MEMBER, \
__LINE__)<Only_##MEMBER<CLASS>, &CLASS::MEMBER>

#define ACCESS(OBJECT, MEMBER) \
(OBJECT).*Access((Only_##MEMBER< \
std::remove_reference<decltype(OBJECT)>::type>*)nullptr)

// NOLINTEND(cppcoreguidelines-macro-usage)

#endif // CLAD_DIFFERENTIATOR_SINS_H
37 changes: 5 additions & 32 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -59,42 +60,14 @@ namespace clad {
return true;
}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};

// Tag used to access Sema::CurScope.
using namespace clang;
struct Sema_CurScope : TagBase<Sema_CurScope, Scope * Sema::*> {};
template struct Rob<Sema_CurScope, &Sema::CurScope>;
} // namespace
ALLOW_ACCESS(Sema, CurScope, Scope*);

clang::Scope*& VisitorBase::getCurrentScope() {
return m_Sema.*get(Sema_CurScope());
return ACCESS(m_Sema, CurScope);
}

void VisitorBase::setCurrentScope(clang::Scope* S) {
m_Sema.*get(Sema_CurScope()) = S;
getCurrentScope() = S;
assert(getEnclosingNamespaceOrTUScope() && "Lost path to base.");
}

Expand Down
8 changes: 8 additions & 0 deletions test/FirstDerivative/CodeGenSimple.C
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...);

int f_1_darg0(int x);

double sq_defined_later(double);

int main() {
int x = 4;
clad::differentiate(f_1, 0);
auto df = clad::differentiate(sq_defined_later, "x");
printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2
printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6
return 0;
}

double sq_defined_later(double x) {
return x * x;
}
81 changes: 81 additions & 0 deletions test/Misc/ClangConsumers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out \
// RUN: -fms-compatibility -DMS_COMPAT -std=c++14 -fmodules \
// RUN: -Xclang -print-stats 2>&1 | FileCheck %s
// CHECK-NOT: {{.*error|warning|note:.*}}
//
// RUN: clang -xc -Xclang -add-plugin -Xclang clad -Xclang -load \
// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \
// RUN: -Xclang -debug-info-kind=limited -Xclang -triple -Xclang bpf-linux-gnu \
// RUN: -S -emit-llvm -Xclang -target-cpu -Xclang generic \
// RUN: -Xclang -print-stats 2>&1 | \
// RUN: FileCheck -check-prefix=CHECK_C %s
// CHECK_C-NOT: {{.*error|warning|note:.*}}
// XFAIL: clang-7, clang-8, clang-9, target={{i586.*}}, target=arm64-apple-{{.*}}
//
// RUN: clang -xobjective-c -Xclang -add-plugin -Xclang clad -Xclang -load \
// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \
// RUN: -Xclang -print-stats 2>&1 | \
// RUN: FileCheck -check-prefix=CHECK_OBJC %s
// CHECK_OBJC-NOT: {{.*error|warning|note:.*}}

#ifdef __cplusplus

#pragma clang module build N
module N {}
#pragma clang module contents
#pragma clang module begin N
struct f { void operator()() const {} };
template <typename T> auto vtemplate = f{};
#pragma clang module end
#pragma clang module endbuild

#pragma clang module import N

#ifdef MS_COMPAT
class __single_inheritance IncSingle;
#endif // MS_COMPAT

struct V { virtual int f(); };
int V::f() { return 1; }
template <typename T> T f() { return T(); }
int i = f<int>();

// Check if shouldSkipFunctionBody is called.
// RUN: %cladclang -I%S/../../include -fsyntax-only -fmodules \
// RUN: -Xclang -code-completion-at=%s:%(line-1):1 %s -o - | \
// RUN: FileCheck -check-prefix=CHECK-CODECOMP %s
// CHECK-CODECOMP: COMPLETION

// CHECK: HandleImplicitImportDecl
// CHECK: AssignInheritanceModel
// CHECK: HandleTopLevelDecl
// CHECK: HandleCXXImplicitFunctionInstantiation
// CHECK: HandleInterestingDecl
// CHECK: HandleVTable
// CHECK: HandleCXXStaticMemberVarInstantiation

#endif // __cplusplus

#ifdef __STDC_VERSION__ // C mode
int i;

extern char ch;
int test(void) { return ch; }
char ch = 1;

// CHECK_C: CompleteTentativeDefinition
// CHECK_C: CompleteExternalDeclaration
#endif // __STDC_VERSION__

#ifdef __OBJC__
@interface I
void f();
@end
// CHECK_OBJC: HandleTopLevelDeclInObjCContainer
#endif // __OBJC__

int main() {
#ifdef __cplusplus
vtemplate<int>();
#endif // __cplusplus
}
3 changes: 3 additions & 0 deletions test/lit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ if.*\[ ?(llvm[^ ]*) ([^ ]*) ?\].*{
if platform.system() not in ['Windows'] or lit_config.getBashPath() != '':
config.available_features.add('shell')


config.available_features.add("clang-{0}".format(config.clang_version_major))

# Loadable module
# FIXME: This should be supplied by Makefile or autoconf.
#if sys.platform in ['win32', 'cygwin']:
Expand Down
1 change: 1 addition & 0 deletions test/lit.site.cfg.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import sys
## Autogenerated by LLVM/clad configuration.
# Do not edit!
llvm_version_major = @LLVM_VERSION_MAJOR@
config.clang_version_major = @CLANG_VERSION_MAJOR@
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
Expand Down
Loading

0 comments on commit d18b840

Please sign in to comment.