forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
engine.cpp
1368 lines (1226 loc) · 53.4 KB
/
engine.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
#include <c10/core/Event.h>
#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <c10/util/Optional.h>
#include <c10/util/ThreadLocal.h>
#include <c10/core/StreamGuard.h>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <unordered_set>
#include <typeinfo>
#include <sstream>
#include <queue>
#include <TH/TH.h>
namespace torch { namespace autograd {
namespace {
// True for children forked after the engine's thread pool was initialized.
// NOTE(review): presumably such children cannot rely on the parent's autograd
// worker threads, because fork() duplicates only the calling thread.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
bool in_bad_autograd_fork = false;

// Fork handler registered by track_bad_autograd_forks(); it runs in the child
// process after fork() and flags that process as a "bad" autograd fork.
void forked_autograd_child() { in_bad_autograd_fork = true; }

// Should be called before fork-unsafe work (starting the thread pool).
// Registers forked_autograd_child as a pthread_atfork child handler exactly
// once per process. Windows has no fork()/pthread_atfork, so this is a no-op
// there. _WIN32 is predefined by every Windows compiler, whereas WIN32 is only
// set by some build systems, so it is the reliable guard (and matches the
// _WIN32 check used elsewhere in this file).
void track_bad_autograd_forks() {
#ifndef _WIN32
  static std::once_flag flag;
  std::call_once(
      flag, [] { pthread_atfork(nullptr, nullptr, forked_autograd_child); });
#endif
}
} // namespace
// Threads spawned by the engine are assigned a 'worker_device' specifying
// what device they process work for. This variable is initialized at:
// 1. thread creation time for CUDA, XLA device threads, as they are
// spinning threads waiting for work on their device.
// 2. before the graph task execution for CPU threads, as for each
// backward call we use the caller thread to drive engine execution.
// This is used when handling reentrant backwards calls;
// See Note [Reentrant backwards]
// Starts as NO_DEVICE; set_device() in this file is what assigns it.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static thread_local int worker_device = NO_DEVICE;
// This variable is true if ALL invocations in the stack of re-entrant engine
// invocations are imperative backwards. This special variable is needed for the
// gradient checkpointing feature only.
// It is saved, narrowed and restored by CheckpointValidGuard in this file.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static thread_local bool checkpoint_valid = true;
// Number of nested reentrant backwards calls currently on this thread
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static thread_local int current_depth = 0;
// For all device threads (i.e. CUDA, XLA), total_depth represents the total nested
// reentrant backwards depths over all device threads.
// For CPU devices, it is the total depth associated with the original backward call.
// Pool threads seed it from graph_task->reentrant_depth_ in reentrant_thread_init().
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static thread_local int total_depth = 0;
// The current GraphTask being executed by this thread. This helps
// queue_callback() to find the target GraphTask to append final callbacks.
// It is installed and restored by GraphTaskGuard; `current_graph_task` below is
// shorthand for this thread's slot.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DEFINE_TLS_static(std::shared_ptr<GraphTask>, tls_current_graph_task);
#define current_graph_task (tls_current_graph_task.get())
// Every autograd worker thread is associated with a ready queue, which specifies
// the stream of work for this thread to do. This shared_ptr is a thread_local
// pointer to each thread's ready_queue, and it should be initialized via the
// Engine::init_local_ready_queue() call in each corresponding thread before execution.
//
// The CUDA, XLA threads are shared among all invocations of backwards via
// device_ready_queues_, while CPU threads are dedicated to processing CPU work for
// the backward they invoked. So any given graph task maintains its own cpu_ready_queue_
// where you should send work for it to be done.
//
// For reentrant backward calls, if we spawn a new thread from the current thread
// because we reached the maximum depth, the new thread will just reuse the same
// ReadyQueue as the parent thread for performance improvement.
// See Note [Reentrant backwards] for more details.
// `local_ready_queue` below is shorthand for this thread's slot.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DEFINE_TLS_static(std::shared_ptr<ReadyQueue>, tls_local_ready_queue);
#define local_ready_queue (tls_local_ready_queue.get())
// Note [Reentrant backwards]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// To understand the reentrant backwards problem, we have to notice two
// aspects of how the autograd engine is implemented today:
//
// 1. When you call Engine::execute(), you want to block until
// differentiation finishes so that you can get the final result variables
// of the backwards pass.
//
// 2. The engine operates by having a single worker thread per work queue,
// and every work queue is pinned to a specific device where the
// operation is executed.
//
// The problem is, suppose that you call backward() inside of a worker
// thread. By property (1), we're supposed to block until the nested task
// finishes. However, by property (2), this worker thread is on the
// hook for processing the tasks assigned to it; we better not block,
// because then all of our backward executions (including the one we
// just started) will deadlock!
//
// We maintain a pool of threads waiting for work to do
// When a reentrant backwards call occurs, the current thread blocks
// and a thread from the pool is woken up to complete the blocking tasks and
// any other tasks that would have been assigned to that worker. If there are no
// threads available, a new thread is spawned. The new thread will continue
// processing tasks from the same ReadyQueue as the parent worker
//
// When the GraphTask is finished, the parent worker thread that is waiting on
// the task is notified and the current thread returns to the pool.
// Note [Streaming backwards]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// On CUDA devices the autograd engine's device operations are run on the
// same stream that ran them in forward. This requires automatically
// syncing the streams so that function A finishes producing its
// output before function B consumes it.
//
// This synchronization occurs when outputs are placed into input buffers.
// The functions corresponding to input buffer positions have metadata
// recording their streams from forward, and during backward this
// data is used to sync the producer's stream with the consumer's.
//
// When a CUDA function is run either all its inputs were accumulated on the
// stream used to run the function OR the inputs are on different devices
// and the function is responsible for properly acquiring them.
//
// User-facing stream semantics of a backward() (or torch.autograd.grad())
// call with respect to surrounding ops are the same as for any other call.
// See "Stream semantics of backward passes" on
// https://pytorch.org/docs/stable/notes/cuda.html
//
// Internally, backward() runs ops (including leaf nodes) on side threads.
// And streams are thread local. So GraphTask achieves the above semantics by
// 1. remembering the current streams on all active CUDA devices
// in the user-facing thread (aka, the thread that called execute() to
// launch the GraphTask)
// 2. remembering the "leaf streams" (streams each backward leaf node ran on)
// 3. during exec_post_processing, for each leaf stream, sync the remembered
// current streams (on the leaf stream's device) with that
// leaf stream.
// Returns the reentrant depth of the GraphTask this NodeTask belongs to.
// An expired GraphTask indicates an error; in that case the maximum int is
// reported so the task is moved to the front of the queue and the autograd
// engine threads pick up the error soon.
int NodeTask::getReentrantDepth() const {
  if (auto owner = base_.lock()) {
    return owner->reentrant_depth_;
  }
  return std::numeric_limits<int>::max();
}
// Narrows the thread-local checkpoint_valid flag for the lifetime of the guard:
// it stays true only if the enclosing state allowed checkpointing AND
// `graph_task` does too. The enclosing value is restored on destruction.
CheckpointValidGuard::CheckpointValidGuard(const std::shared_ptr<const GraphTask>& graph_task) {
  const bool enclosing_state = checkpoint_valid;
  prev_checkpoint_valid_state = enclosing_state;
  checkpoint_valid = graph_task->can_checkpoint() && enclosing_state;
}

CheckpointValidGuard::~CheckpointValidGuard() {
  // Put back whatever the enclosing scope had.
  checkpoint_valid = prev_checkpoint_valid_state;
}
// Enqueues `item` and wakes one thread blocked in pop().
// When `incrementOutstandingTasks` is set, the owning GraphTask's
// outstanding_tasks_ counter is bumped under the same lock that guards heap_.
// Asserts if that GraphTask has already expired.
auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
  {
    std::lock_guard<std::mutex> heap_lock(mutex_);
    if (incrementOutstandingTasks) {
      auto owner = item.base_.lock();
      TORCH_INTERNAL_ASSERT(owner, "GraphTask is no longer valid!");
      ++owner->outstanding_tasks_;
    }
    heap_.push(std::move(item));
  }
  // Notify after the lock is released.
  not_empty_.notify_one();
}
auto ReadyQueue::pushShutdownTask() -> void {
{
std::lock_guard<std::mutex> lock(mutex_);
heap_.push(NodeTask({}, nullptr, InputBuffer(0), true));
}
not_empty_.notify_one();
}
// Number of queued tasks, read under the queue mutex.
size_t ReadyQueue::size() const {
  std::lock_guard<std::mutex> heap_lock(mutex_);
  return heap_.size();
}
// Blocks until the queue is non-empty, then removes and returns the top task.
auto ReadyQueue::pop() -> NodeTask {
  std::unique_lock<std::mutex> heap_lock(mutex_);
  not_empty_.wait(heap_lock, [this] { return !heap_.empty(); });
  // heap_.top() only hands out a const reference; casting the const away lets
  // the task be moved rather than copied. The moved-from element is removed by
  // the heap_.pop() that immediately follows.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  NodeTask task = std::move(const_cast<NodeTask&>(heap_.top()));
  heap_.pop();
  return task;
}
// Whether the queue currently holds no tasks, read under the queue mutex.
bool ReadyQueue::empty() const {
  std::lock_guard<std::mutex> heap_lock(mutex_);
  return heap_.empty();
}
// Starts with the MAX_DEPTH reentrant recursion limit and no registered
// non-reentrant device threads.
Engine::Engine()
    : max_recursion_depth_(MAX_DEPTH),
      non_reentrant_device_thread_count_(0) {}

// Destruction is delegated to stop(), which may ask device threads to shut down.
Engine::~Engine() { stop(); }
// Send shutdown tasks to all device_ready_queues_ if no backward tasks are running.
// Even though the ready queues should be empty, shutdown tasks have the highest priority.
//
// The time spent waiting for device threads to exit is bounded by the
// TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT environment variable, in seconds
// (default 10.0, parsed with std::atof so unparsable text reads as 0).
// A non-positive limit skips sending shutdown tasks altogether.
// Only the first call does anything; later calls return immediately.
void Engine::stop() {
  if (stopped_) {
    return;
  }
  stopped_ = true;
  // Under some conditions, autograd threads can hang on shutdown.
  // Do not wait for them to shut down indefinitely but rely on a timeout.
  // getenv returns a mutable char*; keep our pointer const and fall back to a
  // numeric default instead of assigning a string literal to a char*
  // (ill-formed since C++11).
  const char* wait_duration_str = getenv("TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT");
  const double wait_duration =
      wait_duration_str ? std::atof(wait_duration_str) : 10.0;

  bool noBackward = true;
  for (auto& queue : device_ready_queues_) {
    noBackward = noBackward && queue->empty();
  }
  if (noBackward && wait_duration > 0.0) {
    for (auto& queue : device_ready_queues_) {
      queue->pushShutdownTask();
    }
    // Do not wait for termination of global threads on Windows
    // Because CRT terminates DLL threads before calling
    // global object destructors
#if !defined(_WIN32) || defined(C10_USE_MSVC_STATIC_RUNTIME)
    using namespace std::chrono_literals;
    // Set a deadline for how long it is OK to wait for device threads to shut down
    auto wait_deadline = std::chrono::steady_clock::now() + wait_duration * 1.0s;
    std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
    while (non_reentrant_device_thread_count_.load() != 0) {
      if (non_reentrant_device_thread_condvar_.wait_until(lk, wait_deadline) == std::cv_status::timeout) {
        break;
      }
    }
#endif
  }
  // Otherwise threads are leaked
}
// Resets the non-reentrant device thread count to zero and wakes one waiter
// on the associated condition variable (stop() waits for this count to drop to 0).
void Engine::release_workers() {
  std::lock_guard<std::mutex> guard(non_reentrant_device_thread_mutex_);
  non_reentrant_device_thread_count_ = 0;
  non_reentrant_device_thread_condvar_.notify_one();
}
// Registers one more live non-reentrant device thread and wakes one waiter.
void Engine::increment_non_reentrant_thread_count() {
  std::lock_guard<std::mutex> guard(non_reentrant_device_thread_mutex_);
  ++non_reentrant_device_thread_count_;
  non_reentrant_device_thread_condvar_.notify_one();
}
// Unregisters one non-reentrant device thread and wakes one waiter
// (for example stop(), which waits for the count to reach zero).
void Engine::decrement_non_reentrant_thread_count() {
  std::lock_guard<std::mutex> guard(non_reentrant_device_thread_mutex_);
  --non_reentrant_device_thread_count_;
  non_reentrant_device_thread_condvar_.notify_one();
}
// Entry point of a device worker thread. It optionally registers itself in the
// non-reentrant thread count, binds itself to `device` and `ready_queue`, then
// runs thread_main() with a null GraphTask. That call loops until a shutdown
// task is popped. Finally the thread unregisters if it registered.
void Engine::thread_init(int device, const std::shared_ptr<ReadyQueue>& ready_queue, bool should_increment) {
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }

  at::init_num_threads();

  // Note [Allocating GPUs to autograd threads]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // What's our strategy here? Originally, the autograd engine was written
  // with only CUDA in mind. We allocate one thread to handle all CPU
  // operations, and a thread per CUDA device.
  //
  // But what if we have OTHER devices? There are two plausible
  // strategies:
  //
  // - We can allocate threads equal to max(num_cuda_devices, num_xla_devices,
  // ...) and colocate cuda device 0 with xla device 0
  // - We can allocate threads equal to sum(num_cuda_devices, num_xla_devices,
  // ...) keeping everyone separate.
  //
  // We don't have any good reason to prefer one or the other, so we've
  // arbitrarily picked to colocate devices. Maybe the other approach is
  // better.
  set_device(device);

  // Point this thread's thread-local ready queue at the queue that was
  // created for it before the thread started.
  init_local_ready_queue(ready_queue);

  // A null GraphTask makes thread_main a long-running worker loop.
  thread_main(std::shared_ptr<GraphTask>());

  if (should_increment) {
    // Balance the increment above now that the worker loop has exited.
    decrement_non_reentrant_thread_count();
  }
}
// RAII helper: makes `graph_task` this thread's current_graph_task for the
// guard's lifetime and reinstates the previously current task afterwards.
GraphTaskGuard::GraphTaskGuard(std::shared_ptr<GraphTask> graph_task) {
  std::shared_ptr<GraphTask> previous = std::move(current_graph_task);
  current_graph_task = std::move(graph_task);
  last_graph_task_ = std::move(previous);
}

GraphTaskGuard::~GraphTaskGuard() {
  restore_current_graph_task();
}

// Hands the stashed task back to the thread-local slot.
void GraphTaskGuard::restore_current_graph_task() {
  current_graph_task = std::move(last_graph_task_);
}
// NOTE: graph_tasks do not necessarily form a stack. Imagine this
// case:
//
// +----> Eval1
// Root
// +----> Eval2
//
// Once Root is executed, both Eval1 and Eval2 are added to the ready queue.
// Next, Eval1 is run and this causes the worker to enter thread_main again.
// Then, it pops the next task from the queue, but at this point it is Eval2.
// It enters thread_main once again, but now with graph_task of Eval2, which is
// completely unrelated to that of Eval1 (it's not a recursive call).
// It's all ok and is handled right now, but it should be accounted for
// in case this code is to be changed.
//
// thread_main is used by:
// 1). autograd threads for devices (i.e. CUDA, XLA)
// 2). the caller/owning thread of the backward call on CPU (sync mode)
// 3). Reentrant backwards invoked by either 1) or 2)
// The exit conditions are different for the above three cases.
// For 1), we are spinning on running the thread_main on device autograd
// threads throughout the Engine lifetime, thread_main will get
// terminated during Engine destruction by pushing shutdown tasks
// For 2), the owning thread of the backward call drives the thread_main
// synchronously until the graph_task of that owning thread is
// completed and exit the thread_main to continue executing the
// result of caller's code.
// For 3), the reentrant backward that invokes
// thread_main, either from 1) or 2), will not spin and will exit as
// long as graph_task is completed and notify the owning thread as
// needed.
// Worker loop shared by device threads, CPU caller threads and reentrant
// backward calls (see the NOTE preceding this function for their exit rules).
// Each iteration pops one NodeTask from this thread's local_ready_queue,
// evaluates it if it is still runnable, decrements the owning GraphTask's
// outstanding_tasks_ and, if that GraphTask is now complete, runs its
// post processing and wakes its owning thread.
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
// When graph_task is nullptr, this is a long running thread that processes
// tasks (ex: device threads). When graph_task is non-null (ex: reentrant
// backwards, user thread), this function is expected to exit once that
// graph_task completes.
// local_ready_queue should already have been initialized when we get into thread_main
TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr);
while (graph_task == nullptr || !graph_task->future_result_->completed()) {
// local_graph_task represents the graph_task we retrieve from the queue.
// The outer graph_task represents the overall graph_task we need to execute
// for reentrant execution.
std::shared_ptr<GraphTask> local_graph_task;
{
// Scope this block of execution since NodeTask is not needed after this
// block and can be deallocated (release any references to grad tensors
// as part of inputs_).
NodeTask task = local_ready_queue->pop();
// Shutdown tasks are enqueued by ReadyQueue::pushShutdownTask() (from
// Engine::stop()). This will only work if the worker is running a
// non backward task.
// TODO Needs to be fixed this to work in all cases
if (task.isShutdownTask_) {
C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
break;
}
if (!(local_graph_task = task.base_.lock())) {
// GraphTask for function is no longer valid, skipping further
// execution.
continue;
}
// Tasks with no function (e.g. the wake-up task pushed below) or whose
// GraphTask already recorded an error are not evaluated, but they still
// count against outstanding_tasks_ further down.
if (task.fn_ && !local_graph_task->has_error_.load()) {
AutoGradMode grad_mode(local_graph_task->grad_mode_);
try {
// The guard sets the thread_local current_graph_task on construction
// and restores it on exit. The current_graph_task variable helps
// queue_callback() to find the target GraphTask to append final
// callbacks.
GraphTaskGuard guard(local_graph_task);
NodeGuard ndguard(task.fn_);
evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
} catch (std::exception& e) {
// Records the error on the GraphTask (see thread_on_exception).
thread_on_exception(local_graph_task, task.fn_, e);
}
}
}
// Decrement the outstanding tasks.
--local_graph_task->outstanding_tasks_;
// Check if we've completed execution.
if (local_graph_task->completed()) {
local_graph_task->mark_as_completed_and_run_post_processing();
auto base_owner = local_graph_task->owner_;
// The current worker thread finished the graph_task, but the owning thread
// of the graph_task might be sleeping on pop() if it does not have work.
// So we need to send a dummy function task to the owning thread just to
// ensure that it's not sleeping, so that we can exit the thread_main.
// If it has work, it might see that graph_task->outstanding_tasks_ == 0
// before it gets to the task, but it's a no-op anyway.
//
// NB: This is not necessary if the current thread is the owning thread.
if (worker_device != base_owner) {
// Synchronize outstanding_tasks_ with queue mutex
std::atomic_thread_fence(std::memory_order_release);
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
}
}
}
}
// Reentrant call will re-use the graph_task's owner thread ready_queue for
// queueing tasks (NOTE: this is not true in the async_mode of the engine).
// While we could create a separate ready queue for each new reentrant
// thread, sharing the same cpu_ready_queue with the parent thread is a
// performance improvement, and CUDA threads have to do the same thing anyway.
// Body of a pooled thread that services reentrant backward calls; never returns.
// Each iteration:
//  1. waits for a GraphTask to be queued on the shared thread pool,
//  2. skips it if it has already expired,
//  3. otherwise adopts the owner's device, ready queue and reentrant depth,
//  4. drives the task with thread_main() until it completes.
void Engine::reentrant_thread_init() {
  at::init_num_threads();
  // Hold our own reference to the pool state and go through it for every
  // access below, so this thread never touches the pool via the Engine member
  // (previously num_workers_ was still read through thread_pool_shared_).
  auto tp_shared = thread_pool_shared_;
  while (true) {
    std::unique_lock<std::mutex> lk(tp_shared->mutex_);
    // Count this thread among the pool's waiting workers while it blocks.
    ++tp_shared->num_workers_;
    tp_shared->work_.wait(lk, [&tp_shared] { return !tp_shared->graphtasks_queue_.empty(); });
    --tp_shared->num_workers_;
    auto task = tp_shared->graphtasks_queue_.front();
    tp_shared->graphtasks_queue_.pop();
    lk.unlock();

    std::shared_ptr<GraphTask> graph_task = task.lock();
    if (!graph_task) {
      LOG(INFO) << "GraphTask has expired, skipping reentrant execution";
      continue;
    }
    set_device(graph_task->owner_);
    // set the local_ready_queue to the ready queue on the graph_task->owner_ device
    local_ready_queue = ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_);
    total_depth = graph_task->reentrant_depth_;
    thread_main(graph_task);
  }
}
// Records the exception currently being handled as `graph_task`'s error,
// attributing it to `fn`. The exception object itself is not inspected: the
// in-flight exception is captured with std::current_exception(), so this must
// be called from inside a catch handler.
void Engine::thread_on_exception(
    std::shared_ptr<GraphTask> graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& /*e*/) {
  std::exception_ptr in_flight = std::current_exception();
  graph_task->set_exception(std::move(in_flight), fn);
}
// A GraphTask is finished once nothing is outstanding. Alternatively, when it
// is configured with exit_on_error_, it is finished as soon as an error has
// been recorded.
bool GraphTask::completed() {
  if (outstanding_tasks_.load() == 0) {
    return true;
  }
  return exit_on_error_ && has_error_.load();
}
// Completes this GraphTask's future exactly once. The first caller runs
// exec_post_processing() under mutex_ and then completes the future with the
// captured variables. Any std::exception raised along the way is stored on the
// future instead. Later callers block until the future has been completed.
void GraphTask::mark_as_completed_and_run_post_processing() {
  // future_completed_ is a one-shot latch: only the thread that flips it from
  // false to true does the work below.
  const bool already_claimed = future_completed_.exchange(true);
  if (already_claimed) {
    // Completion may still be in progress on another thread; wait for it so
    // the future is guaranteed to be complete when we return.
    future_result_->wait();
    return;
  }

  try {
    std::unique_lock<std::mutex> lock(mutex_);
    exec_post_processing();
    std::vector<Variable> vars = std::move(captured_vars_);
    // markCompleted runs the future's callbacks; release the lock first so
    // they do not execute while mutex_ is held.
    lock.unlock();
    // NOLINTNEXTLINE(performance-move-const-arg)
    future_result_->markCompleted(std::move(vars));
  } catch (const std::exception&) {
    future_result_->setErrorIfNeeded(std::current_exception());
  }
}
void GraphTask::exec_post_processing() {
if (!not_ready_.empty()) {
throw std::runtime_error("could not compute gradients for some functions");
}
// set the thread_local current_graph_task_ as more callbacks can be installed
// by existing final callbacks.
GraphTaskGuard guard(shared_from_this());
// Lock mutex during each iteration for accessing final_callbacks.size()
// Unlocking is necessary, because the callback can register
// more callbacks (or they can be registered from other threads
// while it's waiting.
std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);
// caller_current_streams_ with nullopt entries removed
std::vector<c10::Stream> caller_current_streams_filtered;
// See Note [Streaming backwards].
// Syncs caller_current_stream with leaf streams, so final_callbacks may use
// any grad on its device's current stream.
if (leaf_streams.size() > 0) {
for (const auto& leaf_stream : leaf_streams) {
// stash_current_streams() stashed streams for all device IDs that already had a
// CUDA context before the GraphTask executed. For inactive devices, it stashed
// a c10::nullopt. I don't expect GraphTask's backward pass ran leaf nodes on
// any new devices, so the stashed streams should be enough.
// If leaf_stream.device_index() happens to be for a new device,
// operator* on the c10::nullopt should throw an error.
const auto caller_current_stream = *caller_current_streams_[leaf_stream.device_index()];
if (caller_current_stream != leaf_stream) {
auto event = c10::Event{c10::DeviceType::CUDA};
event.record(leaf_stream);
caller_current_stream.wait(event);
}
}
caller_current_streams_filtered.reserve(caller_current_streams_.size());
for (const auto& opt_stream : caller_current_streams_) {
if (opt_stream.has_value()) {
caller_current_streams_filtered.push_back(*opt_stream);
}
}
}
{
// final_callbacks run on the per-device caller_current_streams (the ambient streams
// surrounding the user's call to backward()). This has two benefits:
// 1. caller_current_streams have been synced with leaf_streams, so callbacks may
// safely access any grad.
// 2. The callback's results can safely be used on (user-facing) caller_current_streams
// after backward().
c10::MultiStreamGuard g(caller_current_streams_filtered);
// WARNING: Don't use a range-for loop here because more callbacks may be
// added in between callback calls, so iterators may become invalidated.
// NOLINTNEXTLINE(modernize-loop-convert)
for (size_t i = 0; i < final_callbacks_.size(); ++i) {
cb_lock.unlock();
final_callbacks_[i]();
cb_lock.lock();
}
}
}
// Marks this GraphTask as having failed without touching its future. Only the
// first failure is reported. When anomaly mode is on and `fn` is known, that
// node's recorded forward stack is printed.
void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
  const bool had_error_already = has_error_.exchange(true);
  if (had_error_already) {
    return;
  }
  if (AnomalyMode::is_enabled() && fn) {
    fn->metadata()->print_stack(fn->name());
  }
}
// Records the failure (see set_exception_without_signal). It also completes
// the future with `eptr`, unless the future has already been claimed for
// completion.
void GraphTask::set_exception(
    std::exception_ptr eptr,
    const std::shared_ptr<Node>& fn) {
  set_exception_without_signal(fn);
  if (future_completed_.exchange(true)) {
    return;
  }
  // NOLINTNEXTLINE(performance-move-const-arg)
  future_result_->setError(std::move(eptr));
}
// Pipes `inputs` through each pre-hook registered on `fn`, in order; every
// hook receives the previous hook's result.
static variable_list call_pre_hooks(Node& fn, variable_list inputs) {
  const auto& hooks = fn.pre_hooks();
  for (const auto& hook_ptr : hooks) {
    auto& hook = *hook_ptr;
    inputs = hook(inputs);
  }
  return inputs;
}
// Pipes `outputs` through each post-hook registered on `fn`, in order. Every
// hook receives the previous result plus the node's original `inputs`.
static variable_list call_post_hooks(Node& fn, variable_list outputs, const variable_list& inputs) {
  const auto& hooks = fn.post_hooks();
  for (const auto& hook_ptr : hooks) {
    auto& hook = *hook_ptr;
    outputs = hook(outputs, inputs);
  }
  return outputs;
}
// A gradient's options are compatible with the expected ones in two cases:
// the types match exactly, or the gradient is sparse and lives on the same
// device type as expected.
static bool is_compatible_type(const at::TensorOptions& expected, const at::TensorOptions& actual) {
  if (expected.type_equal(actual)) {
    return true;
  }
  return actual.is_sparse() && actual.device().type() == expected.device().type();
}
// Makes `device` the current device index for every registered device type
// that has at least `device + 1` devices, and records it in the thread-local
// worker_device.
void set_device(int device) {
  // NB: We MUST NOT construct the guard for device CPU,
  // as in some settings we compile with cuda, but
  // have lazy stubs for CUDA functionality (so actually
  // attempting to setup a guard(CPU_DEVICE) will cause an
  // error, because it will still query cudaGetDevice).
  //
  // Don't use DeviceGuard here because its destructor may be called before the
  // device is reset. This is fine because the device is thread local.
  if (device != CPU_DEVICE) {
    constexpr auto kNumDeviceTypes =
        static_cast<size_t>(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
    for (size_t type_idx = 0; type_idx < kNumDeviceTypes; ++type_idx) {
      auto* guard_impl = c10::impl::device_guard_impl_registry[type_idx].load();
      if (guard_impl == nullptr || device >= guard_impl->deviceCount()) {
        continue;
      }
      guard_impl->setDevice(at::Device(static_cast<c10::DeviceType>(type_idx), device));
    }
  }
  worker_device = device;
}
// Checks, and where possible repairs, the gradients in `grads` against the
// input metadata of the nodes they flow into along `edges`. For every valid
// edge i whose gradient is defined:
//  - a shape mismatch is reduced with at::sum_to when the gradient's shape is
//    expandable from the expected one, otherwise it is an error;
//  - a non-floating gradient must agree with the input on complex vs. real;
//    a dtype mismatch is then resolved by casting to the expected dtype;
//  - a 0-dim gradient on the wrong device is moved to the expected device;
//  - any remaining type or device mismatch is an error.
// Undefined gradients are skipped. Every error is raised with AT_ERROR, using
// the message built by `format_error`.
void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error) {
  if (grads.size() != edges.size()) {
    std::stringstream ss;
    ss << "invalid number of gradients - expected ";
    ss << edges.size() << ", but got " << grads.size();
    AT_ERROR(format_error(ss.str()));
  }
  for (const auto i : c10::irange(grads.size())) {
    const auto& edge = edges[i];
    if (!edge.is_valid()) continue;

    const auto& metadata = edge.function->input_metadata(edge.input_nr);
    auto& grad = grads[i];
    if (!grad.defined()) {
      // FIXME: TestJit.test_ge_optimized fails this assertion.
      // std::stringstream ss;
      // ss << "undefined gradient at index " << i;
      // AT_ERROR(format_error(ss.str()));
      continue;
    }

    if (!grad.sizes().equals(metadata.shape())) {
      if (!at::is_expandable_to(metadata.shape(), grad.sizes())) {
        std::stringstream ss;
        ss << "invalid gradient at index " << i << " - got ";
        ss << grad.sizes() << " but expected shape compatible with ";
        ss << metadata.shape();
        AT_ERROR(format_error(ss.str()));
      }
      grad = at::sum_to(std::move(grad), metadata.shape());
    }

    // The metadata does not change below, so compute the expected dtype once.
    const auto expected_dtype = c10::typeMetaToScalarType(metadata.options().dtype());
    const bool input_is_complex = isComplexType(expected_dtype);
    const bool grad_is_complex = isComplexType(grad.scalar_type());
    // Same condition as the former bare TORCH_CHECK, but it now reports which
    // gradient is wrong and why, in the same format as the other errors here.
    if (!isFloatingType(grad.scalar_type()) && input_is_complex != grad_is_complex) {
      std::stringstream ss;
      ss << "invalid gradient at index " << i << " - expected dtype ";
      ss << expected_dtype << " but got " << grad.scalar_type();
      AT_ERROR(format_error(ss.str()));
    }
    if (expected_dtype != grad.scalar_type()) {
      grad = grad.to(expected_dtype);
    }
    if (grad.device() != metadata.device() &&
        grad.dim() == 0) {
      grad = grad.to(metadata.device());
    }
    if (!is_compatible_type(metadata.options(), grad.options())) {
      std::stringstream ss;
      ss << "invalid gradient at index " << i << " - expected type ";
      ss << metadata.options() << " but got " << grad.options();
      AT_ERROR(format_error(ss.str()));
    }
    auto grad_device = grad.device();
    if (grad_device != metadata.device()) {
      std::stringstream ss;
      ss << "invalid gradient at index " << i << " - expected device ";
      ss << metadata.device() << " but got " << grad_device;
      AT_ERROR(format_error(ss.str()));
    }
    // We should not build graph for Tensors that are not differentiable
    TORCH_INTERNAL_ASSERT(isDifferentiableType(grad.scalar_type()));
  }
}
// Runs a single Node `func` on the gradients collected in `inputBuffer`:
// pre-hooks -> the node itself -> validate_outputs against the node's next
// edges -> post-hooks (which also receive the node's inputs).
// The InputBuffer is consumed (moved from). While the call runs,
// CheckpointValidGuard narrows the thread-local checkpoint_valid flag.
static variable_list call_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputBuffer) {
CheckpointValidGuard cpvguard(graph_task);
auto& fn = *func;
auto inputs =
call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
// When the graph is not being kept, tell the node up front that its saved
// variables will be released.
if (!graph_task->keep_graph_) {
fn.will_release_variables();
}
const auto has_post_hooks = !fn.post_hooks().empty();
variable_list outputs;
if (has_post_hooks) {
// In functions/accumulate_grad.cpp, there is some logic to check the
// conditions under which the incoming gradient can be stolen directly
// (which elides a deep copy) instead of cloned. One of these conditions
// is that the incoming gradient's refcount must be 1 (nothing else is
// referencing the same data). Stashing inputs_copy here bumps the
// refcount, so if post hooks are employed, it's actually still ok for
// accumulate_grad.cpp to steal the gradient if the refcount is 2.
//
// "new_grad.use_count() <= 1 + !post_hooks().empty()" in
// accumulate_grad.cpp accounts for this, but also creates a silent
// dependency between engine.cpp (ie, this particular engine
// implementation) and accumulate_grad.cpp.
//
// If you change the logic here, make sure it's compatible with
// accumulate_grad.cpp.
auto inputs_copy = inputs;
outputs = fn(std::move(inputs_copy));
} else {
// No post hooks need `inputs` afterwards, so hand them over outright.
outputs = fn(std::move(inputs));
}
// Error messages from validate_outputs are prefixed with the node's name.
validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
std::ostringstream ss;
ss << "Function " << fn.name() << " returned an " << msg;
return ss.str();
});
if(has_post_hooks){
// `inputs` was not moved from on this path (only inputs_copy was).
// NOLINTNEXTLINE(bugprone-use-after-move)
return call_post_hooks(fn, std::move(outputs), inputs);
}
return outputs;
}
void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // Evaluates `func` on the gradients accumulated in `inputs`. Each output
  // gradient is then routed into the InputBuffer of the matching next edge.
  // Any next node whose dependency count reaches zero is pushed as a NodeTask
  // onto the queue chosen by ready_queue().
  //
  // Parameters:
  //   graph_task      - owning task; its mutex_ is held while this function
  //                     touches captured_vars_, leaf_streams, dependencies_
  //                     and not_ready_.
  //   func            - the node to evaluate.
  //   inputs          - incoming gradients for `func`; handed to
  //                     call_function, which consumes them.
  //   cpu_ready_queue - forwarded to ready_queue() when scheduling next nodes.
  //
  // Throws std::runtime_error when anomaly mode is enabled and an output
  // contains NaN, or when a next node has no entry in dependencies_.

  // Set the ThreadLocalState before calling the function.
  // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
  // always saves ThreadLocalState without grad_mode.
  at::ThreadLocalStateGuard tls_guard(graph_task->thread_locals_);

  // The InputBuffer::adds that supplied incoming grads took pains to
  // ensure they're safe to consume in the context of the present
  // func's stream (if applicable). So we guard onto that stream
  // before working with the grads in any capacity.
  const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
  c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};

  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    auto& fn_info = exec_info_.at(func);
    if (auto* capture_vec = fn_info.captures_.get()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
        captured_grad = inputs[capture.input_idx_];
        for (auto& hook : capture.hooks_) {
          captured_grad = (*hook)(captured_grad);
        }
        if (opt_parent_stream) {
          // No need to take graph_task->mutex_ here, we already hold it
          graph_task->leaf_streams.emplace(*opt_parent_stream);
        }
      }
    }
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.
      return;
    }
  }

  auto outputs = call_function(graph_task, func, inputs);

  auto& fn = *func;
  if (!graph_task->keep_graph_) {
    fn.release_variables();
  }

  // Keep the count as size_t. It is used only as an index and loop bound, so
  // narrowing it to int would be a lossy conversion with no benefit.
  const auto num_outputs = outputs.size();

  if (num_outputs == 0) {
    // Nothing to route downstream. The mutex is taken only to record the
    // leaf stream (if applicable).
    // See Note [Streaming backwards]
    if (opt_parent_stream) {
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      graph_task->leaf_streams.emplace(*opt_parent_stream);
    }
    return;
  }

  if (AnomalyMode::is_enabled()) {
    AutoGradMode grad_mode(false);
    for (const auto i : c10::irange(num_outputs)) {
      auto& output = outputs[i];
      at::OptionalDeviceGuard guard(device_of(output));
      if (output.defined() && isnan(output).any().item<uint8_t>()) {
        std::stringstream ss;
        ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
        throw std::runtime_error(ss.str());
      }
    }
  }

  // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below
  std::lock_guard<std::mutex> lock(graph_task->mutex_);
  for (const auto i : c10::irange(num_outputs)) {
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i);

    if (!next.is_valid()) continue;

    // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_;
    auto it = dependencies.find(next.function.get());

    if (it == dependencies.end()) {
      auto name = next.function->name();
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      dependencies.erase(it);
      is_ready = true;
    }

    auto& not_ready = graph_task->not_ready_;
    auto not_ready_it = not_ready.find(next.function.get());
    if (not_ready_it == not_ready.end()) {
      // Skip functions that aren't supposed to be executed.
      // (A distinct name avoids shadowing the dependency iterator `it`.)
      if (!exec_info_.empty()) {
        auto exec_it = exec_info_.find(next.function.get());
        if (exec_it == exec_info_.end() || !exec_it->second.should_execute()) {
          continue;
        }
      }
      // No buffers have been allocated for the function
      InputBuffer input_buffer(next.function->num_inputs());

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);

      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        not_ready.emplace(next.function.get(), std::move(input_buffer));
      }
    } else {
      // The function already has a buffer
      auto &input_buffer = not_ready_it->second;

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);
      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
        not_ready.erase(not_ready_it);
      }
    }
  }
}
inline static uint64_t compute_min_topological_nr(const edge_list& outputs) {
  // Returns the smallest topological_nr() among the nodes that the `outputs`
  // edges point to, or 0 when `outputs` is empty.
  // NOTE(review): every edge's function is dereferenced without a null check,
  // so callers presumably guarantee that each output edge has a node.
  uint64_t smallest = 0;
  if (!outputs.empty()) {
    smallest = std::numeric_limits<uint64_t>::max();
    for (const auto& edge : outputs) {
      const auto candidate = edge.function->topological_nr();
      if (candidate < smallest) {
        smallest = candidate;
      }
    }
  }
  return smallest;
}
auto Engine::compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr) -> void {
  // Walks the graph reachable from `root` and counts, for every node reached
  // through a non-null edge, how many incoming edges point at it. The counts
  // are stored in task.dependencies_.
  // Nodes whose topological_nr() is below `min_topo_nr` are popped but not
  // expanded. Presumably such nodes cannot reach any requested output.
  // If any expanded node reports a CUDA stream, the task's current streams
  // are stashed afterwards.
  auto& dependencies = task.dependencies_;
  const bool cuda_available = at::globalContext().hasCUDA();
  bool needs_stream_stash = false;

  // `visited` records nodes already scheduled. `root` itself is never added.
  std::unordered_set<Node*> visited;
  std::vector<Node*> pending;
  pending.push_back(root);

  while (!pending.empty()) {
    Node* const current = pending.back();
    pending.pop_back();

    if (current->topological_nr() < min_topo_nr) {
      continue;
    }

    if (cuda_available && !needs_stream_stash) {
      needs_stream_stash = current->stream(c10::DeviceType::CUDA).has_value();
    }

    for (const auto& edge : current->next_edges()) {
      Node* const successor = edge.function.get();
      if (successor == nullptr) {
        continue;
      }
      ++dependencies[successor];
      if (visited.insert(successor).second) {
        pending.push_back(successor);
      }
    }
  }

  if (needs_stream_stash) {
    // Collects current streams for devices where this process has a context,
    // so GraphTask::exec_post_processing can sync them with leaf_streams.
    task.stash_current_streams();
  }
}
auto Engine::execute(const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs) -> variable_list {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
validate_outputs(roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
return msg;
});
if (accumulate_grad && create_graph) {
TORCH_WARN_ONCE(
"Using backward() with create_graph=True will create a reference cycle "
"between the parameter and its gradient which can cause a memory leak. "
"We recommend using autograd.grad when creating the graph to avoid this. "
"If you have to use this function, make sure to reset the .grad fields of "
"your parameters to None after use to break the cycle and avoid the leak.");
}
// accumulate_grad is true if and only if the frontend call was to
// grad(), not backward(). grad() returns the sum of the gradients
// w.r.t. the inputs and thus needs the inputs to be present.
TORCH_CHECK_VALUE(accumulate_grad || !outputs.empty(),
"grad requires non-empty inputs.");
// A fresh first time Engine::execute call should start on the CPU device, initialize
// a new thread local ready queue on CPU or reuse the existing one (if there is one
// allocated already, i.e. consecutive backward calls, re-entrant backward calls),
// then memoize the local_ready_queue in GraphTask
init_local_ready_queue();
bool not_reentrant_backward_call = worker_device == NO_DEVICE;
auto graph_task = std::make_shared<GraphTask>(
/* keep_graph */ keep_graph,
/* create_graph */ create_graph,
/* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
/* cpu_ready_queue */ local_ready_queue);
// If we receive a single root, skip creating extra root node
bool skip_dummy_node = roots.size() == 1;
auto graph_root = skip_dummy_node ?
roots.at(0).function :
std::make_shared<GraphRoot>(roots, inputs);
auto min_topo_nr = compute_min_topological_nr(outputs);
// Now compute the dependencies for all executable functions
compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);
if (!outputs.empty()) {
graph_task->init_to_execute(*graph_root, outputs, accumulate_grad, min_topo_nr);