Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't create adjoints for integral type variables #864

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demos/ErrorEstimation/FloatSum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ int main() {
finalError = 0;
unsigned int dn = 0;
// First execute the derived function.
df.execute(x, n, &ret[0], &dn, finalError);
df.execute(x, n, &ret[0], &dn, &finalError);

double kahanResult = kahanSum(x, n);
double vanillaResult = vanillaSum(x, n);
Expand Down
2 changes: 1 addition & 1 deletion demos/ErrorEstimation/PrintModel/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ int main() {
// Calculate the error
float dx, dy;
double error;
df.execute(2, 3, &dx, &dy, error);
df.execute(2, 3, &dx, &dy, &error);
}
2 changes: 0 additions & 2 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class BaseForwardModeVisitor

virtual void ExecuteInsidePushforwardFunctionBlock();

static bool IsDifferentiableType(clang::QualType T);

virtual StmtDiff
VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
Expand Down
15 changes: 7 additions & 8 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ namespace custom_derivatives {
#ifdef __CUDACC__
template <typename T>
ValueAndPushforward<cudaError_t, cudaError_t>
cudaMalloc_pushforward(T** devPtr, size_t sz, T** d_devPtr, size_t d_sz)
cudaMalloc_pushforward(T** devPtr, size_t sz, T** d_devPtr)
__attribute__((host)) {
return {cudaMalloc(devPtr, sz), cudaMalloc(d_devPtr, sz)};
}

ValueAndPushforward<cudaError_t, cudaError_t>
cudaMemcpy_pushforward(void* destPtr, void* srcPtr, size_t count,
cudaMemcpyKind kind, void* d_destPtr, void* d_srcPtr,
size_t d_count) __attribute__((host)) {
cudaMemcpyKind kind, void* d_destPtr, void* d_srcPtr)
__attribute__((host)) {
return {cudaMemcpy(destPtr, srcPtr, count, kind),
cudaMemcpy(d_destPtr, d_srcPtr, count, kind)};
}
Expand Down Expand Up @@ -199,18 +199,17 @@ CUDA_HOST_DEVICE void clamp_pullback(const T& v, const T& lo, const T& hi,

// NOLINTBEGIN(cppcoreguidelines-no-malloc)
// NOLINTBEGIN(cppcoreguidelines-owning-memory)
inline ValueAndPushforward<void*, void*> malloc_pushforward(size_t sz,
size_t d_sz) {
inline ValueAndPushforward<void*, void*> malloc_pushforward(size_t sz) {
return {malloc(sz), malloc(sz)};
}

inline ValueAndPushforward<void*, void*>
calloc_pushforward(size_t n, size_t sz, size_t d_n, size_t d_sz) {
inline ValueAndPushforward<void*, void*> calloc_pushforward(size_t n,
size_t sz) {
return {calloc(n, sz), calloc(n, sz)};
}

inline ValueAndPushforward<void*, void*>
realloc_pushforward(void* ptr, size_t sz, void* d_ptr, size_t d_sz) {
realloc_pushforward(void* ptr, size_t sz, void* d_ptr) {
return {realloc(ptr, sz), realloc(d_ptr, sz)};
}

Expand Down
5 changes: 5 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ namespace clad {

bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);

/// Removes the local const qualifiers from a QualType and returns a new
/// type.
clang::QualType getNonConstType(clang::QualType T, clang::ASTContext& C,
clang::Sema& S);
} // namespace utils
} // namespace clad

Expand Down
20 changes: 0 additions & 20 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,6 @@ namespace clad {
clang::SourceLocation& noLoc,
clang::DeclarationNameInfo name,
clang::QualType functionType);
/// Looks for a suitable overload for a given function.
///
/// \param[in] Name The identification information of the function
/// overload to be found.
/// \param[in] CallArgs The call args to be used to resolve to the
/// correct overload.
/// \param[in] forCustomDerv A flag to keep track of which
/// namespace we should look in for the overloads.
/// \param[in] namespaceShouldExist A flag to enforce assertion failure
/// if the overload function namespace was not found. If false and
/// the function containing namespace was not found, nullptr is returned.
///
/// \returns The call expression if a suitable function overload was found,
/// null otherwise.
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
/// Shorthand to issues a warning or error.
template <std::size_t N>
void diag(clang::DiagnosticsEngine::Level level, // Warning or Error
Expand Down
10 changes: 5 additions & 5 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,14 +495,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {

template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedEstFnTraits_t<F>>
CladFunction<DerivedFnType> __attribute__((annotate("E")))
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
annotate("E")))
estimate_error(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType>(derivedFn /* will be replaced by estimation code*/,
code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by estimation code*/, code);
}

// Gradient Structure for Reverse Mode Enzyme
Expand Down
7 changes: 3 additions & 4 deletions include/clad/Differentiator/FunctionTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,7 @@ namespace clad {
// GradientDerivedEstFnTraits specializations for pure function pointer types
template <class ReturnType, class... Args>
struct GradientDerivedEstFnTraits<ReturnType (*)(Args...)> {
using type = void (*)(Args..., OutputParamType_t<Args, Args>...,
double&);
using type = void (*)(Args..., OutputParamType_t<Args, void>..., void*);
};

/// These macro expansions are used to cover all possible cases of
Expand All @@ -498,8 +497,8 @@ namespace clad {
#define GradientDerivedEstFnTraits_AddSPECS(var, cv, vol, ref, noex) \
template <typename R, typename C, typename... Args> \
struct GradientDerivedEstFnTraits<R (C::*)(Args...) cv vol ref noex> { \
using type = void (C::*)(Args..., OutputParamType_t<Args, Args>..., \
double&) cv vol ref noex; \
using type = void (C::*)(Args..., OutputParamType_t<Args, void>..., \
void*) cv vol ref noex; \
};

#if __cpp_noexcept_function_type > 0
Expand Down
17 changes: 6 additions & 11 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,6 @@ namespace clad {
return "_grad";
}

/// Removes the local const qualifiers from a QualType and returns a new
/// type.
static clang::QualType
getNonConstType(clang::QualType T, clang::ASTContext& C, clang::Sema& S) {
clang::Qualifiers quals(T.getQualifiers());
quals.removeConst();
return S.BuildQualifiedType(T.getUnqualifiedType(), noLoc, quals);
}
// Function to Differentiate with Clad as Backend
void DifferentiateWithClad();

Expand Down Expand Up @@ -197,8 +189,9 @@ namespace clad {
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit) {
assert(E && "cannot infer type from null expression");
return StoreAndRef(E, getNonConstType(E->getType(), m_Context, m_Sema), d,
prefix, forceDeclCreation, IS);
return StoreAndRef(
E, clad::utils::getNonConstType(E->getType(), m_Context, m_Sema), d,
prefix, forceDeclCreation, IS);
}

/// An overload allowing to specify the type for the variable.
Expand Down Expand Up @@ -443,7 +436,9 @@ namespace clad {
/// Builds an overload for the gradient function that has derived params for
/// all the arguments of the requested function and it calls the original
/// gradient function internally
clang::FunctionDecl* CreateGradientOverload();
/// \param[in] numExtraParam The number of extra parameters requested by an
/// external source (e.g. the final error in error estimation).
clang::FunctionDecl* CreateGradientOverload(unsigned numExtraParam = 0);

/// Returns the type that should be used to represent the derivative of a
/// variable of type `yType` with respect to a parameter variable of type
Expand Down
24 changes: 24 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace clad {

#include "Compatibility.h"
#include "DerivativeBuilder.h"
#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffMode.h"

#include "clang/AST/RecursiveASTVisitor.h"
Expand Down Expand Up @@ -207,6 +208,8 @@ namespace clad {
return QT->isArrayType() || QT->isPointerType();
}

static bool IsDifferentiableType(clang::QualType T);

clang::CompoundStmt* MakeCompoundStmt(const Stmts& Stmts);

/// Get the latest block of code (i.e. place for statements output).
Expand Down Expand Up @@ -646,6 +649,27 @@ namespace clad {
void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);
/// Looks for a suitable overload for a given function.
///
/// \param[in] Name The identification information of the function
/// overload to be found.
/// \param[in] CallArgs The call args to be used to resolve to the
/// correct overload.
/// \param[in] forCustomDerv A flag to keep track of which
/// namespace we should look in for the overloads.
/// \param[in] namespaceShouldExist A flag to enforce assertion failure
/// if the overload function namespace was not found. If false and
/// the function containing namespace was not found, nullptr is returned.
///
/// \returns The call expression if a suitable function overload was found,
/// null otherwise.
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, const clang::FunctionDecl* originalFD,
bool forCustomDerv = true, bool namespaceShouldExist = true,
llvm::SmallVectorImpl<clang::Stmt*>* block = nullptr);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
};
} // end namespace clad

Expand Down
57 changes: 24 additions & 33 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,6 @@ BaseForwardModeVisitor::BaseForwardModeVisitor(DerivativeBuilder& builder)

BaseForwardModeVisitor::~BaseForwardModeVisitor() {}

bool BaseForwardModeVisitor::IsDifferentiableType(QualType T) {
QualType origType = T;
// FIXME: arbitrary dimension array type as well.
while (utils::isArrayOrPointerType(T))
T = utils::GetValueType(T);
T = T.getNonReferenceType();
if (T->isEnumeralType())
return false;
if (T->isRealType() || T->isStructureOrClassType())
return true;
if (origType->isPointerType() && T->isVoidType())
return true;
return false;
}

bool IsRealNonReferenceType(QualType T) {
return T.getNonReferenceType()->isRealType();
}
Expand Down Expand Up @@ -224,7 +209,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
if (!IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
continue;
Expr* dParam = nullptr;
Expand Down Expand Up @@ -420,7 +405,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
for (auto* PVD : m_Function->parameters()) {
paramTypes.push_back(PVD->getType());

if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
if (IsDifferentiableType(PVD->getType()))
derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType()));
}

Expand Down Expand Up @@ -485,7 +470,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
if (identifierMissing)
m_DeclReplacements[PVD] = newPVD;

if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
if (!IsDifferentiableType(PVD->getType()))
continue;
auto derivedPVDName = "_d_" + std::string(PVDII->getName());
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName);
Expand Down Expand Up @@ -1069,7 +1054,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
}
}
CallArgs.push_back(argDiff.getExpr());
if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) {
if (IsDifferentiableType(arg->getType())) {
Expr* dArg = argDiff.getExpr_dx();
// FIXME: What happens when dArg is nullptr?
diffArgs.push_back(dArg);
Expand All @@ -1094,9 +1079,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// Try to find a user-defined overloaded derivative.
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
Expr* callDiff = BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(), FD);

// Check if it is a recursive call.
if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) {
Expand Down Expand Up @@ -1188,7 +1172,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
}
}
}

// If clad failed to derive it, try finding its derivative using
// numerical diff.
if (!callDiff) {
Expand Down Expand Up @@ -1314,9 +1297,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
opDiff = BuildOp(opCode, derivedL, derivedR);
} else if (BinOp->isAssignmentOp()) {
if (Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
// If the LHS has a non-differentiable type, Ldiff.getExpr_dx() will be 0.
// Don't create a warning then.
if (IsDifferentiableType(BinOp->getLHS()->getType()))
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
opDiff = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
} else if (opCode == BO_Assign || opCode == BO_AddAssign ||
opCode == BO_SubAssign) {
Expand Down Expand Up @@ -1393,10 +1379,13 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
BuildVarDecl(VD->getType(), VD->getNameAsString(), initDiff.getExpr(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
// FIXME: Create unique identifier for derivative.
VarDecl* VDDerived = BuildVarDecl(
VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
VarDecl* VDDerived = nullptr;
if (IsDifferentiableType(VD->getType())) {
VDDerived = BuildVarDecl(VD->getType(), "_d_" + VD->getNameAsString(),
initDiff.getExpr_dx(), VD->isDirectInit(), nullptr,
VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
}
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

Expand Down Expand Up @@ -1458,7 +1447,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
if (VDDiff.getDecl_dx())
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
Expand Down Expand Up @@ -1597,7 +1587,7 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
// ...
// ...
// }
if (condVarClone) {
if (condVarRes.getDecl_dx()) {
bodyResult = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), cast<CompoundStmt>(bodyResult),
BuildDeclStmt(condVarRes.getDecl_dx()));
Expand Down Expand Up @@ -1676,7 +1666,8 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
if (condVarDecl) {
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
if (condVarDeclDiff.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
}

StmtDiff initVarRes = (SS->getInit() ? Visit(SS->getInit()) : StmtDiff());
Expand Down
7 changes: 7 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,5 +684,12 @@ namespace clad {
return FD->getNameAsString() == "free";
#endif
}

clang::QualType getNonConstType(clang::QualType T, clang::ASTContext& C,
clang::Sema& S) {
clang::Qualifiers quals(T.getQualifiers());
quals.removeConst();
return S.BuildQualifiedType(T.getUnqualifiedType(), noLoc, quals);
}
} // namespace utils
} // namespace clad
Loading
Loading