Skip to content

Commit

Permalink
Improve DiffRequest and DynamicGraph printing.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 23, 2024
1 parent 36dbc6e commit 1a55b91
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 21 deletions.
19 changes: 11 additions & 8 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
#ifndef CLAD_DIFF_PLANNER_H
#define CLAD_DIFF_PLANNER_H

#include "clang/AST/RecursiveASTVisitor.h"
#include "llvm/ADT/SmallSet.h"
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include "clang/AST/RecursiveASTVisitor.h"

#include "llvm/Support/raw_ostream.h"

#include <iterator>
#include <set>

namespace clang {
class CallExpr;
class CompilerInstance;
Expand Down Expand Up @@ -132,15 +135,15 @@ struct DiffRequest {

const clang::FunctionDecl* operator->() const { return Function; }

// String operator for printing the node.
operator std::string() const {
std::string res = BaseFunctionName + "__order_" +
std::to_string(CurrentDerivativeOrder) + "__mode_" +
DiffModeToString(Mode);
if (EnableTBRAnalysis)
res += "__TBR";
std::string res;
llvm::raw_string_ostream s(res);
print(s);
s.flush();
return res;
}
void print(llvm::raw_ostream& Out) const;
void dump() const { print(llvm::errs()); }

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;
Expand Down
4 changes: 1 addition & 3 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
#include "BuiltinDerivativesCUDA.cuh"
#endif
#include "CladConfig.h"
#include "DynamicGraph.h"
#include "FunctionTraits.h"
#include "Matrix.h"
#include "NumericalDiff.h"
#include "Tape.h"

#include <assert.h>
#include <stddef.h>
#include <array>
#include <cstring>

namespace clad {
Expand Down
20 changes: 12 additions & 8 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <iostream>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -109,23 +110,26 @@ template <typename T> class DynamicGraph {
const std::vector<T>& getNodes() const { return m_nodes; }
std::vector<T>& getNodes() { return m_nodes; }

/// Dump the nodes and edges.
void dump() const { print(std::cerr); }

/// Print the nodes and edges in the graph.
void print() {
void print(std::ostream& Out) const {
// First print the nodes with their insertion order.
for (const T& node : m_nodes) {
std::pair<bool, int> nodeInfo = m_nodeMap[node];
std::cout << (std::string)node << ": #" << nodeInfo.second;
std::pair<bool, int> nodeInfo = m_nodeMap.at(node);
Out << (std::string)node << ": #" << nodeInfo.second;
if (m_sources.find(nodeInfo.second) != m_sources.end())
std::cout << " (source)";
Out << " (source)";
if (nodeInfo.first)
std::cout << ", (done)\n";
Out << ", (done)\n";
else
std::cout << ", (unprocessed)\n";
Out << ", (unprocessed)\n";

Check warning on line 127 in include/clad/Differentiator/DynamicGraph.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/DynamicGraph.h#L127

Added line #L127 was not covered by tests
}
// Then print the edges.
for (int i = 0; i < m_nodes.size(); i++)
for (size_t dest : m_adjList[i])
std::cout << i << " -> " << dest << "\n";
for (size_t dest : m_adjList.at(i))
Out << i << " -> " << dest << "\n";
}

/// Get the next node to be processed from the queue of nodes to be
Expand Down
15 changes: 15 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "clad/Differentiator/DiffPlanner.h"

#include "clad/Differentiator/DiffMode.h"

#include "ActivityAnalyzer.h"
#include "TBRAnalyzer.h"

Expand Down Expand Up @@ -601,6 +603,19 @@ namespace clad {
return;
}

void DiffRequest::print(llvm::raw_ostream& Out) const {
Out << '<';
PrintingPolicy Policy(Function->getASTContext().getLangOpts());
Function->getNameForDiagnostic(Out, Policy, /*Qualified=*/true);
Out << ">[name=" << BaseFunctionName << ", "
<< "order=" << CurrentDerivativeOrder << ", "
<< "mode=" << DiffModeToString(Mode);
if (EnableTBRAnalysis)
Out << ", tbr";
Out << ']';
Out.flush();
}

bool DiffRequest::shouldBeRecorded(Expr* E) const {
if (!EnableTBRAnalysis)
return true;
Expand Down
11 changes: 11 additions & 0 deletions test/Misc/TimingsReport.C
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
// RUN: %cladclang %s -I%S/../../include -oTimingsReport.out -ftime-report 2>&1 | %filecheck %s
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -print-stats 2>&1 | %filecheck -check-prefix=CHECK_STATS %s
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -plugin-arg-clad -Xclang -enable-tbr -Xclang -print-stats 2>&1 | %filecheck -check-prefix=CHECK_STATS_TBR %s

#include "clad/Differentiator/Differentiator.h"
// CHECK: Timers for Clad Funcs
// CHECK_STATS: *** INFORMATION ABOUT THE DIFF REQUESTS
// CHECK_STATS-NEXT: <test1>[name=test1, order=1, mode=forward]: #0 (source), (done)
// CHECK_STATS-NEXT: <test2>[name=test2, order=1, mode=reverse]: #1 (source), (done)
// CHECK_STATS-NEXT: <nested1>[name=nested1, order=1, mode=pushforward]: #2, (done)
// CHECK_STATS-NEXT: <nested2>[name=nested2, order=1, mode=pullback]: #3, (done)
// CHECK_STATS-NEXT: 0 -> 2
// CHECK_STATS-NEXT: 1 -> 3

// CHECK_STATS_TBR: <test1>[name=test1, order=1, mode=forward, tbr]: #0 (source), (done)

double nested1(double c){
return c*3*c;
Expand Down
3 changes: 2 additions & 1 deletion tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "clad/Differentiator/Compatibility.h"

#include <algorithm>
#include <iostream> // for std::cerr

using namespace clang;

Expand Down Expand Up @@ -526,7 +527,7 @@ namespace clad {

// Print the graph of the diff requests.
llvm::errs() << "\n*** INFORMATION ABOUT THE DIFF REQUESTS\n";
m_DiffRequestGraph.print();
m_DiffRequestGraph.print(std::cerr);

m_Multiplexer->PrintStats();
}
Expand Down
3 changes: 2 additions & 1 deletion unittests/Misc/DynamicGraph.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/Differentiator.h"

#include <iostream>
Expand Down Expand Up @@ -46,7 +47,7 @@ TEST(DynamicGraphTest, Printing) {
std::stringstream ss;
std::streambuf* coutbuf = std::cout.rdbuf();
std::cout.rdbuf(ss.rdbuf());
G.print();
G.print(std::cout);
std::cout.rdbuf(coutbuf);
std::string expectedOutput = "node0: #0 (source), (unprocessed)\n"
"node1: #1, (unprocessed)\n"
Expand Down

0 comments on commit 1a55b91

Please sign in to comment.