From da55781b61e34362f1366c0221a232fec6a50744 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Sat, 14 Dec 2024 15:25:55 +0000 Subject: [PATCH] Constify interfaces. NFC --- include/clad/Differentiator/CladUtils.h | 2 +- include/clad/Differentiator/DerivativeBuilder.h | 6 +++--- lib/Differentiator/BaseForwardModeVisitor.cpp | 8 +++----- lib/Differentiator/CladUtils.cpp | 6 +++--- lib/Differentiator/DerivativeBuilder.cpp | 7 ++++--- lib/Differentiator/ReverseModeVisitor.cpp | 11 ++++------- 6 files changed, 18 insertions(+), 22 deletions(-) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 071a0e516..349c0a87e 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -126,7 +126,7 @@ namespace clad { /// such declaration context is found, then returns `nullptr`. clang::DeclContext* FindDeclContext(clang::Sema& semaRef, clang::DeclContext* DC1, - clang::DeclContext* DC2); + const clang::DeclContext* DC2); /// Finds the qualified name `name` in the declaration context `DC`. /// diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 5e9d54ac2..7c9d843d7 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -117,7 +117,7 @@ namespace clad { /// null otherwise. clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff( const std::string& Name, llvm::SmallVectorImpl& CallArgs, - clang::Scope* S, clang::DeclContext* originalFnDC, + clang::Scope* S, const clang::DeclContext* originalFnDC, bool forCustomDerv = true, bool namespaceShouldExist = true, clang::Expr* CUDAExecConfig = nullptr); bool noOverloadExists(clang::Expr* UnresolvedLookup, @@ -150,7 +150,7 @@ namespace clad { /// \returns The lookup result of the custom derivative or numerical /// differentiation function. clang::LookupResult LookupCustomDerivativeOrNumericalDiff( - const std::string& Name, clang::DeclContext* originalFnDC, + const std::string& Name, const clang::DeclContext* originalFnDC, clang::CXXScopeSpec& SS, bool forCustomDerv = true, bool namespaceShouldExist = true); @@ -160,7 +160,7 @@ namespace clad { /// \returns The custom derivative function if found, nullptr otherwise. clang::FunctionDecl* LookupCustomDerivativeDecl(const std::string& Name, - clang::DeclContext* originalFnDC, + const clang::DeclContext* originalFnDC, clang::QualType functionType); public: diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 4a9c81412..382ba00db 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1227,7 +1227,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix(); callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, customDerivativeArgs, getCurrentScope(), - const_cast(FD->getDeclContext())); + FD->getDeclContext()); // Custom derivative templates can be written in a // general way that works for both vectorized and non-vectorized // modes. We have to also look for the pushforward with the regular name. @@ -1236,7 +1236,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { clad::utils::ComputeEffectiveFnName(FD) + "_pushforward"; callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, customDerivativeArgs, getCurrentScope(), - const_cast(FD->getDeclContext())); + FD->getDeclContext()); } if (!isLambda) { // Check if it is a recursive call. @@ -2315,11 +2315,9 @@ clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall( std::string customPushforwardName = clad::utils::ComputeEffectiveFnName(CE->getConstructor()) + GetPushForwardFunctionSuffix(); - // FIXME: We should not use const_cast to get the decl context here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) Expr* pushforwardCall = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforwardName, customPushforwardArgs, getCurrentScope(), - const_cast(CE->getConstructor()->getDeclContext())); + CE->getConstructor()->getDeclContext()); return pushforwardCall; } } // end namespace clad diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index fc3bb0b02..19d0c3172 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -187,8 +187,8 @@ namespace clad { } DeclContext* FindDeclContext(clang::Sema& semaRef, clang::DeclContext* DC1, - clang::DeclContext* DC2) { - llvm::SmallVector contexts; + const clang::DeclContext* DC2) { + llvm::SmallVector contexts; assert((isa(DC1) || isa(DC1)) && "DC1 can only be extended if it is a " "namespace or translation unit decl."); @@ -240,7 +240,7 @@ namespace clad { } DeclContext* DC = DC1; for (int i = contexts.size() - 1; i >= 0; --i) { - NamespaceDecl* ND = cast(contexts[i]); + const auto* ND = cast(contexts[i]); if (ND->getIdentifier()) DC = LookupNSD(semaRef, ND->getIdentifier()->getName(), /*shouldExist=*/false, DC1); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 465f35aee..859a52b18 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -30,6 +30,7 @@ #include "llvm/Support/SaveAndRestore.h" #include +#include #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/Compatibility.h" @@ -167,7 +168,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { } LookupResult DerivativeBuilder::LookupCustomDerivativeOrNumericalDiff( - const std::string& Name, clang::DeclContext* originalFnDC, + const std::string& Name, const clang::DeclContext* originalFnDC, CXXScopeSpec& SS, bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) { @@ -222,7 +223,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { } FunctionDecl* DerivativeBuilder::LookupCustomDerivativeDecl( - const std::string& Name, clang::DeclContext* originalFnDC, + const std::string& Name, const clang::DeclContext* originalFnDC, QualType functionType) { CXXScopeSpec SS; LookupResult R = @@ -246,7 +247,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( const std::string& Name, llvm::SmallVectorImpl& CallArgs, - clang::Scope* S, clang::DeclContext* originalFnDC, + clang::Scope* S, const clang::DeclContext* originalFnDC, bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/, Expr* CUDAExecConfig /*=nullptr*/) { CXXScopeSpec SS; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index a7d4dc6cb..e5188ac91 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -266,7 +266,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, originalFnType->getExtProtoInfo()); // Check if the function is already declared as a custom derivative. - // FIXME: We should not use const_cast to get the decl context here. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( @@ -1828,7 +1827,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, pushforwardCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext()), + FD->getDeclContext(), /*forCustomDerv=*/true, /*namespaceShouldExist=*/true, CUDAExecConfig); if (OverloadedDerivedFn) @@ -1932,7 +1931,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPullback, pullbackCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext()), + FD->getDeclContext(), /*forCustomDerv=*/true, /*namespaceShouldExist=*/true, CUDAExecConfig); if (baseDiff.getExpr()) @@ -4248,8 +4247,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (Expr* customPullbackCall = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPullbackName, pullbackArgs, getCurrentScope(), - const_cast( - CE->getConstructor()->getDeclContext()))) { + CE->getConstructor()->getDeclContext())) { curRevBlock.insert(it, customPullbackCall); if (m_TrackConstructorPullbackInfo) { setConstructorPullbackCallInfo(llvm::cast(customPullbackCall), @@ -4585,8 +4583,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, args.append(derivedArgs.begin(), derivedArgs.end()); Expr* customForwPassCE = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - forwPassFnName, args, getCurrentScope(), - const_cast(FD->getDeclContext())); + forwPassFnName, args, getCurrentScope(), FD->getDeclContext()); return customForwPassCE; }