From 9f245ce4b15f03e6f9d819b541620a9b5ddc024e Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Sun, 22 Dec 2024 16:47:43 +0000 Subject: [PATCH] Improve DiffRequest and DynamicGraph printing. --- include/clad/Differentiator/DiffPlanner.h | 19 +++++++++++-------- include/clad/Differentiator/Differentiator.h | 4 +--- include/clad/Differentiator/DynamicGraph.h | 20 ++++++++++++-------- lib/Differentiator/DiffPlanner.cpp | 15 +++++++++++++++ test/Misc/TimingsReport.C | 11 +++++++++++ tools/ClangPlugin.cpp | 3 ++- unittests/Misc/DynamicGraph.cpp | 3 ++- 7 files changed, 54 insertions(+), 21 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 30b483b7e..54c4ee3eb 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -1,14 +1,17 @@ #ifndef CLAD_DIFF_PLANNER_H #define CLAD_DIFF_PLANNER_H -#include "clang/AST/RecursiveASTVisitor.h" -#include "llvm/ADT/SmallSet.h" #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clang/AST/RecursiveASTVisitor.h" + +#include "llvm/Support/raw_ostream.h" + #include #include + namespace clang { class CallExpr; class CompilerInstance; @@ -132,15 +135,15 @@ struct DiffRequest { const clang::FunctionDecl* operator->() const { return Function; } - // String operator for printing the node. operator std::string() const { - std::string res = BaseFunctionName + "__order_" + - std::to_string(CurrentDerivativeOrder) + "__mode_" + - DiffModeToString(Mode); - if (EnableTBRAnalysis) - res += "__TBR"; + std::string res; + llvm::raw_string_ostream s(res); + print(s); + s.flush(); return res; } + void print(llvm::raw_ostream& Out) const; + void dump() const { print(llvm::errs()); } bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index c8aaaa286..83a430dad 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -14,14 +14,12 @@ #include "BuiltinDerivativesCUDA.cuh" #endif #include "CladConfig.h" -#include "DynamicGraph.h" #include "FunctionTraits.h" #include "Matrix.h" #include "NumericalDiff.h" #include "Tape.h" -#include -#include +#include #include namespace clad { diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index 2ef8cf992..dc58a0292 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -109,23 +110,26 @@ template class DynamicGraph { const std::vector& getNodes() const { return m_nodes; } std::vector& getNodes() { return m_nodes; } + /// Dump the nodes and edges. + void dump() const { print(std::cerr); } + /// Print the nodes and edges in the graph. - void print() { + void print(std::ostream& Out) const { // First print the nodes with their insertion order. for (const T& node : m_nodes) { - std::pair nodeInfo = m_nodeMap[node]; - std::cout << (std::string)node << ": #" << nodeInfo.second; + std::pair nodeInfo = m_nodeMap.at(node); + Out << (std::string)node << ": #" << nodeInfo.second; if (m_sources.find(nodeInfo.second) != m_sources.end()) - std::cout << " (source)"; + Out << " (source)"; if (nodeInfo.first) - std::cout << ", (done)\n"; + Out << ", (done)\n"; else - std::cout << ", (unprocessed)\n"; + Out << ", (unprocessed)\n"; } // Then print the edges. for (int i = 0; i < m_nodes.size(); i++) - for (size_t dest : m_adjList[i]) - std::cout << i << " -> " << dest << "\n"; + for (size_t dest : m_adjList.at(i)) + Out << i << " -> " << dest << "\n"; } /// Get the next node to be processed from the queue of nodes to be diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index d2c39c1d9..cd90e0c01 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -1,5 +1,7 @@ #include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/DiffMode.h" + #include "ActivityAnalyzer.h" #include "TBRAnalyzer.h" @@ -601,6 +603,19 @@ namespace clad { return; } + void DiffRequest::print(llvm::raw_ostream& Out) const { + Out << '<'; + PrintingPolicy Policy(Function->getASTContext().getLangOpts()); + Function->getNameForDiagnostic(Out, Policy, /*Qualified=*/true); + Out << ">[name=" << BaseFunctionName << ", " + << "order=" << CurrentDerivativeOrder << ", " + << "mode=" << DiffModeToString(Mode); + if (EnableTBRAnalysis) + Out << ", tbr"; + Out << ']'; + Out.flush(); + } + bool DiffRequest::shouldBeRecorded(Expr* E) const { if (!EnableTBRAnalysis) return true; diff --git a/test/Misc/TimingsReport.C b/test/Misc/TimingsReport.C index 6b8dc9282..7e0988139 100644 --- a/test/Misc/TimingsReport.C +++ b/test/Misc/TimingsReport.C @@ -1,7 +1,18 @@ // RUN: %cladclang %s -I%S/../../include -oTimingsReport.out -ftime-report 2>&1 | %filecheck %s +// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -print-stats 2>&1 | %filecheck -check-prefix=CHECK_STATS %s +// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -plugin-arg-clad -Xclang -enable-tbr -Xclang -print-stats 2>&1 | %filecheck -check-prefix=CHECK_STATS_TBR %s #include "clad/Differentiator/Differentiator.h" // CHECK: Timers for Clad Funcs +// CHECK_STATS: *** INFORMATION ABOUT THE DIFF REQUESTS +// CHECK_STATS-NEXT: [name=test1, order=1, mode=forward]: #0 (source), (done) +// CHECK_STATS-NEXT: [name=test2, order=1, mode=reverse]: #1 (source), (done) +// CHECK_STATS-NEXT: [name=nested1, order=1, mode=pushforward]: #2, (done) +// CHECK_STATS-NEXT: [name=nested2, order=1, mode=pullback]: #3, (done) +// CHECK_STATS-NEXT: 0 -> 2 +// CHECK_STATS-NEXT: 1 -> 3 + +// CHECK_STATS_TBR: [name=test1, order=1, mode=forward, tbr]: #0 (source), (done) double nested1(double c){ return c*3*c; diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 9e1977e0b..ee11c04bf 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -30,6 +30,7 @@ #include "clad/Differentiator/Compatibility.h" #include +#include // for std::cerr using namespace clang; @@ -526,7 +527,7 @@ namespace clad { // Print the graph of the diff requests. llvm::errs() << "\n*** INFORMATION ABOUT THE DIFF REQUESTS\n"; - m_DiffRequestGraph.print(); + m_DiffRequestGraph.print(std::cerr); m_Multiplexer->PrintStats(); } diff --git a/unittests/Misc/DynamicGraph.cpp b/unittests/Misc/DynamicGraph.cpp index 6954a6698..04d8f90fa 100644 --- a/unittests/Misc/DynamicGraph.cpp +++ b/unittests/Misc/DynamicGraph.cpp @@ -1,3 +1,4 @@ +#include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/Differentiator.h" #include @@ -46,7 +47,7 @@ TEST(DynamicGraphTest, Printing) { std::stringstream ss; std::streambuf* coutbuf = std::cout.rdbuf(); std::cout.rdbuf(ss.rdbuf()); - G.print(); + G.print(std::cout); std::cout.rdbuf(coutbuf); std::string expectedOutput = "node0: #0 (source), (unprocessed)\n" "node1: #1, (unprocessed)\n"