Skip to content

Commit

Permalink
Use one IsDifferentiableType for all modes and move it to VisitorBase
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Apr 22, 2024
1 parent c993a64 commit 1c2dc1e
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 33 deletions.
2 changes: 0 additions & 2 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class BaseForwardModeVisitor

virtual void ExecuteInsidePushforwardFunctionBlock();

static bool IsDifferentiableType(clang::QualType T);

virtual StmtDiff
VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
Expand Down
2 changes: 0 additions & 2 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,6 @@ namespace clad {
bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);

bool IsDifferentiableType(clang::QualType QT);

/// Removes the local const qualifiers from a QualType and returns a new
/// type.
clang::QualType getNonConstType(clang::QualType T, clang::ASTContext& C,
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ namespace clad {
return QT->isArrayType() || QT->isPointerType();
}

static bool IsDifferentiableType(clang::QualType T);

clang::CompoundStmt* MakeCompoundStmt(const Stmts& Stmts);

/// Get the latest block of code (i.e. place for statements output).
Expand Down
23 changes: 4 additions & 19 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,6 @@ BaseForwardModeVisitor::BaseForwardModeVisitor(DerivativeBuilder& builder)

BaseForwardModeVisitor::~BaseForwardModeVisitor() {}

bool BaseForwardModeVisitor::IsDifferentiableType(QualType T) {
QualType origType = T;
// FIXME: arbitrary dimension array type as well.
while (utils::isArrayOrPointerType(T))
T = utils::GetValueType(T);
T = T.getNonReferenceType();
if (T->isEnumeralType())
return false;
if (T->isRealType() || T->isStructureOrClassType())
return true;
if (origType->isPointerType() && T->isVoidType())
return true;
return false;
}

bool IsRealNonReferenceType(QualType T) {
return T.getNonReferenceType()->isRealType();
}
Expand Down Expand Up @@ -224,7 +209,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
if (!IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
continue;
Expr* dParam = nullptr;
Expand Down Expand Up @@ -420,7 +405,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
for (auto* PVD : m_Function->parameters()) {
paramTypes.push_back(PVD->getType());

if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
if (IsDifferentiableType(PVD->getType()))
derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType()));
}

Expand Down Expand Up @@ -485,7 +470,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
if (identifierMissing)
m_DeclReplacements[PVD] = newPVD;

if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
if (!IsDifferentiableType(PVD->getType()))
continue;
auto derivedPVDName = "_d_" + std::string(PVDII->getName());
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName);
Expand Down Expand Up @@ -1069,7 +1054,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
}
}
CallArgs.push_back(argDiff.getExpr());
if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) {
if (IsDifferentiableType(arg->getType())) {
Expr* dArg = argDiff.getExpr_dx();
// FIXME: What happens when dArg is nullptr?
diffArgs.push_back(dArg);
Expand Down
5 changes: 0 additions & 5 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,6 @@ namespace clad {
#endif
}

bool IsDifferentiableType(clang::QualType QT) {
// FIXME: consider analysing object types with this
return !utils::GetValueType(QT)->isIntegerType();
}

clang::QualType getNonConstType(clang::QualType T, clang::ASTContext& C,
clang::Sema& S) {
clang::Qualifiers quals(T.getQualifiers());
Expand Down
8 changes: 4 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: we have to create adjoints for all parameters when any
// external sources are enabled because gradient overloads don't support
// additional parameters.
if (utils::IsDifferentiableType(dParam.param->getType()) || m_ExternalSource)
if (IsDifferentiableType(dParam.param->getType()) || m_ExternalSource)
args.push_back(dParam.param);
}
else
Expand Down Expand Up @@ -604,7 +604,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: we have to create adjoints for all parameters when any
// external sources are enabled because gradient overloads don't support
// additional parameters.
if (!utils::IsDifferentiableType(param->getType()) && !m_ExternalSource)
if (!IsDifferentiableType(param->getType()) && !m_ExternalSource)
continue;
// derived variables are already created for independent variables.
if (m_Variables.count(param))
Expand Down Expand Up @@ -1548,7 +1548,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// because the derivatives of arguments passed by reference are directly
// modified by the derived callee function.
// Also, no need to create adjoint variables for non-differentiable types.
if (utils::IsReferenceOrPointerArg(arg) || !utils::IsDifferentiableType(arg->getType())) {
if (utils::IsReferenceOrPointerArg(arg) || !IsDifferentiableType(arg->getType())) {
argDiff = Visit(arg);
CallArgDx.push_back(argDiff.getExpr_dx());
} else {
Expand Down Expand Up @@ -2585,7 +2585,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Integer types are not differentiable,
// no need to construct an adjoint.
if (!utils::IsDifferentiableType(VD->getType())) {
if (!IsDifferentiableType(VD->getType())) {
Expr* init = nullptr;
if (VD->getInit())
init = Visit(VD->getInit()).getExpr();
Expand Down
19 changes: 18 additions & 1 deletion lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ namespace clad {
m_Mode != DiffMode::experimental_pushforward)
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!utils::IsDifferentiableType(paramTy)) {
if (!IsDifferentiableType(paramTy)) {
QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema);
VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy));
Expr* arg = BuildDeclRef(argDecl);
Expand Down Expand Up @@ -965,4 +965,21 @@ namespace clad {
}
return false;
}



bool VisitorBase::IsDifferentiableType(QualType T) {
QualType origType = T;
// FIXME: arbitrary dimension array type as well.
while (utils::isArrayOrPointerType(T))
T = utils::GetValueType(T);
T = T.getNonReferenceType();
if (T->isEnumeralType())
return false;
if (T->isFloatingType() || T->isStructureOrClassType())
return true;
if (origType->isPointerType() && T->isVoidType())
return true;
return false;
}
} // end namespace clad

0 comments on commit 1c2dc1e

Please sign in to comment.