forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_task.h
229 lines (195 loc) · 9.16 KB
/
graph_task.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
#pragma once
#include <ATen/ThreadLocalState.h>
#include <ATen/core/Tensor.h>
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/utils/warnings.h>
#include <vector>
namespace torch::autograd {
using edge_list = std::vector<Edge>;
struct ReadyQueue;
static constexpr int NO_DEVICE = -2;
static constexpr int CPU_DEVICE = -1;
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask : std::enable_shared_from_this<GraphTask> {
std::atomic<uint64_t> outstanding_tasks_{0};
// Indicates if an error occurred while executing any task. When this is
// true, it signals all threads to stop executing.
std::atomic_bool has_error_{false};
std::atomic_bool future_completed_{false};
// It is safe to read keep_graph_ without synchronization
bool keep_graph_;
// To protect reads/writes to not_ready_, dependencies_, captured_vars_,
// has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
std::mutex mutex_;
std::unordered_map<Node*, InputBuffer> not_ready_;
std::unordered_map<Node*, int> dependencies_;
// Records the nodes that are in the graph
std::unordered_set<Node*> nodes_in_graph_;
c10::SmallVector<Node*, 4> graph_roots_;
// Note [Exec info]
// Exec info is created for each GraphTask, which allows filtering paths on
// the graph that are not needed. It has a bit complicated semantics. If it's
// empty, it means the task is run in a "default" mode, which means that all
// next_edges we encounter should get executed. If it's not empty, only
// functions that have an entry and this entry has needed == True should be
// executed. exec_info is only empty when the graph is executed via
// .backward() and the inputs parameter is not passed. Otherwise, when
// executed through .grad(), or when inputs arg is specified for .backward(),
// exec_info will be non-empty.
//
struct ExecInfo {
struct Capture {
Capture(const Capture&) = delete;
Capture(Capture&&) = default;
Capture& operator=(const Capture&) = delete;
Capture& operator=(Capture&&) = default;
~Capture() = default;
Capture(int input_idx, int output_idx)
: input_idx_(input_idx), output_idx_(output_idx) {}
int input_idx_; // within Node inputs
int output_idx_; // within the output vector of a GraphTask
// This hook will be executed after a grad is captured. The captured
// grad will be replaced by the return value of the hook.
struct GradCaptureHook {
virtual ~GradCaptureHook() = default;
virtual at::Tensor operator()(const at::Tensor& grad) = 0;
};
// NOTE [Deprecated capture hooks]
//
// The current status of capture hooks is that we continue to support
// the single usage of it by distributed in the dist_engine. If anyone
// else needs to use it for other purposes, they should file an issue.
//
// Capture hooks were originally created because there did not exist
// any way to register pre/post hooks to grad_fn in a way such that it
// would still be executed even if that is the grad_fn of a Tensor
// passed as input= of .grad. As far as I know, only dist_engine uses
// this hook.
//
// However, there are other alternatives today like tensor hooks that can
// replace the usage that originally motivated its creation. Also,
// Captures hooks are an outlier in terms of the types of hook that
// autograd offers in how it is registered and behaves, e.g. it is a hook
// registered not to the graph, but to a particular graph_task! This makes
// it a burden to maintain.
//
// It would be very nice to clean up/do a migration from pre/post
// hooks used in distributed to use tensor hooks, but for now we just
// mark this method as deprecated to prevent additional usage.
//
// If you still think you really need to capture hooks, please file an
// issue (and tag autograd).
const std::vector<std::unique_ptr<GradCaptureHook>>&
DO_NOT_USE_DEPRECATED_get_capture_hooks() const {
return hooks_;
}
// See NOTE [deprecated capture hooks]
void DO_NOT_USE_DEPRECATED_register_capture_hook(
std::unique_ptr<GradCaptureHook> hook) {
hooks_.push_back(std::move(hook));
}
private:
// The hooks will be called one by one in the order as they were added.
// The input grad of a hook will be the output of its preceding hook. The
// first hook will take the captured grad as the input. The output of the
// last hook will replace the captured grad.
std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
};
bool should_execute() const {
return needed_ || captures_;
}
bool needed_ = false;
std::unique_ptr<std::vector<Capture>> captures_;
};
// exec_info_ is safe to read without synchronization
std::unordered_map<Node*, ExecInfo> exec_info_;
// Captures variables are grads captured that we return to the user. After
// execution of the GraphTask is completed, the captured_vars_ are moved
// out of the GraphTask and are no longer valid.
std::vector<Variable> captured_vars_;
// Note: this field is not ready to be used until the proper
// `thread_locals_.set_grad_mode()` call in the constructor.
at::ThreadLocalState thread_locals_ = at::ThreadLocalState();
std::unordered_set<c10::Stream> leaf_streams;
// Per-device current streams of the execute() that called this GraphTask.
// These will be synced with leaf_streams in exec_post_processing.
std::vector<std::optional<c10::Stream>> caller_current_streams_;
// Collects caller_current_streams_ for the accelerator device.
void stash_current_streams();
void init_to_execute(
Node& graph_root,
const edge_list& outputs,
bool accumulate_grad,
uint64_t min_topo_nr);
// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
// Safe to read owner_ and reentrant_depth_ without synchronization
int owner_;
// The number of parent graph tasks for this graph task
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const int reentrant_depth_;
bool can_checkpoint() const {
return exec_info_.empty();
}
// check if the GraphTask is completed or not
bool completed();
// mark the graph task as completed and trigger post processing
void mark_as_completed_and_run_post_processing();
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function.
void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function. But doesn't signal completion on
// 'future_result_' right away. The user needs to explicitly mark
// 'future_result_' completed with an appropriate exception.
void set_exception_without_signal(const std::shared_ptr<Node>& fn);
// Whether or not to stop execution for this GraphTask when an error is
// encountered. When set to true, this would cause Engine::execute() to throw
// an exception as soon as the autograd engine receives an exception.
bool exit_on_error_;
// CPU threads are dedicated to processing CPU work for the backward they
// invoked. So any given graph task maintains its own cpu_ready_queue_ where
// you should send work for it to be done. We memoize the cpu_ready_queue_ per
// GraphTask so that we know which ready queue we should push to if we are on
// device thread (i.e. GPU) and but next NodeTask should be run on CPU.
std::shared_ptr<ReadyQueue> cpu_ready_queue_;
// Future representing the completion of the graph task. Notified when all
// tasks are done.
c10::intrusive_ptr<at::ivalue::Future> future_result_;
// Final callbacks installed during execution of this GraphTask
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_. Intentionally no reusing
// mutex_ as the two are protecting different data structures.
std::mutex final_callbacks_lock_;
utils::DelayWarningHandler warning_handler_;
uint64_t id_;
GraphTask(
bool keep_graph,
bool grad_mode,
int reentrant_depth,
std::shared_ptr<ReadyQueue> cpu_ready_queue,
c10::SmallVector<Node*, 4> graph_roots,
bool exit_on_error = false);
private:
// run GraphTask post processing
void exec_post_processing();
};
// The guard that sets and restores current_graph_task.
class GraphTaskGuard {
public:
explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
~GraphTaskGuard();
void restore_current_graph_task();
private:
std::shared_ptr<GraphTask> last_graph_task_;
};
TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info();
TORCH_API const std::unordered_set<Node*>*
get_current_graph_task_nodes_in_graph();
TORCH_API bool get_current_graph_task_keep_graph();
TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
TORCH_API int get_current_graph_task_id();
void add_node_to_current_graph_task_exec_info(Node* fn);
} // namespace torch::autograd