From c435b8c10a1b9652a35e1a87cee6c3942356660d Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 17 Nov 2023 09:03:08 -0800 Subject: [PATCH] Fix autograd engine callback error propagation from device thread (#113702) The existing try-catch doesn't work because it doesn't call err.persist(). This is in contrast to the try-catch for evaluate_function which does work because it calls into python_engine's thread_on_exception which calls persist. Calling persist on a python_error stashes the PyErr state from the thread-local PyThreadState onto the python_error object, so that when this error object is stored onto the future and passed back to the calling cpu thread, python_engine's execute try-catch can then err.restore() the error state. Finally, the python_engine's execute would re-raise so that this is re-caught by the HANDLE_TH_ERRORS macro. Fixes https://github.com/pytorch/pytorch/issues/75750 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113702 Approved by: https://github.com/albanD --- test/inductor/test_compiled_autograd.py | 1 + test/test_autograd.py | 34 +++++++++++++++++++++++++ torch/csrc/autograd/engine.cpp | 1 + torch/csrc/autograd/python_engine.cpp | 22 ++++++++++++++-- torch/csrc/autograd/python_function.cpp | 18 ++++++------- 5 files changed, 64 insertions(+), 12 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 5904f5adbeae58..aaef65650a1721 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -535,6 +535,7 @@ def wrapped(self: EagerAutogradTests): "test_unrelated_inputs", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients "test_will_engine_execute_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd "test_backward_to_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd + "test_callback_propagates_errors_from_device_thread", # AssertionError: "blah" does not match "call_method UserDefinedObj..." } if not HAS_CUDA: diff --git a/test/test_autograd.py b/test/test_autograd.py index 10e8ad40969215..40ebf5b86a4f69 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -6336,6 +6336,22 @@ def backward(ctx, grad): self.assertEqual(called[0], 2) + @unittest.skipIf(not TEST_CUDA, "test requires CUDA") + def test_callback_propagates_errors_from_device_thread(self): + def callback(): + raise RuntimeError("blah") + + def hook_with_callback(*args): + torch.autograd.Variable._execution_engine.queue_callback(callback) + + t = torch.tensor([1., 2.], requires_grad=True, device=torch.device("cuda")) + t.register_hook(hook_with_callback) + output = t ** 2 + loss = output.sum() + + with self.assertRaisesRegex(RuntimeError, "blah"): + loss.backward() + def _test_reentrant_with_callbacks(self, install_callbacks_in_depths): counter = {} counter["inner"] = 0 @@ -11281,6 +11297,24 @@ def test_set_multithreading_enabled_as_context_manager_and_function(self): torch.autograd.set_multithreading_enabled(True) self.assertTrue(torch.autograd.is_multithreading_enabled()) + @unittest.skipIf(not TEST_CUDA, "test requires CUDA") + def test_custom_function_propagates_errors_from_device_thread(self): + class MyFunc(Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, gO): + raise RuntimeError("blah") + return gO + + t = torch.tensor([1., 2.], requires_grad=True, device=torch.device("cuda")) + out = MyFunc.apply(t).sum() + + with self.assertRaisesRegex(RuntimeError, "blah"): + out.backward() + class TestNestedCheckpoint(TestCase): @staticmethod def grad(fn): diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 2e5ce1807c88d2..d996a024caa55a 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -575,6 +575,7 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { local_graph_task->cpu_ready_queue_); } } catch (std::exception& e) { + // See Note [ Persisting PyErr state across autograd engine threads ] thread_on_exception(local_graph_task, task.fn_, e); } } diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 90f93c5890b5d7..a5112837d0024c 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -106,6 +106,7 @@ void PythonEngine::thread_on_exception( std::shared_ptr graph_task, const std::shared_ptr& fn, std::exception& e) { + // See Note [ Persisting PyErr state across autograd engine threads ] auto python_err = dynamic_cast(&e); if (python_err) { python_err->persist(); @@ -392,8 +393,25 @@ PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) { engine.queue_callback([callback]() { pybind11::gil_scoped_acquire gil; THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)}; - if (!result) - throw python_error(); + if (!result) { + // Note [ Persisting PyErr state across autograd engine threads ] + // + // Since the autograd engine is multi-threaded, and Python error state is + // local to each thread, it must preserve the python error from the worker + // thread and rethrow it as-is in the calling thread. This is done via + // persisting the error in the two places that can encounter Python + // errors: (1) evaluate function and (2) queued callbacks. + // + // TODO: the engine is not actually responsible for persisting the error + // in the custom autograd Function case today! See the note above + // `raise_python_error()` function in python_function.cpp and + // python_hooks.cpp for more details. Persisting an extra time in the + // engine is fine because doing so is a no-op when the python_error has + // already been persisted. + python_error err; + err.persist(); + throw std::move(err); + } }); Py_RETURN_NONE; END_HANDLE_TH_ERRORS diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 9a73d23c1db2a2..93646684a8cdcb 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -56,16 +56,14 @@ PyObject* THPGradientEdgeClass = nullptr; // Anonymous namespace for helpful functions used in this file namespace { -// Throw a python_error with the PyErr state persisted, so that we -// don't lose the error state if the GIL is released when we don't -// have a PyThreadState created beforehand, this is made so that -// even for pure C++ thread without a pre-created PyThreadState could -// also capture the correct error message. -// TODO: This is a temporary approach to allow C++ thread to correctly -// capture Python Error in autograd, remove this when c10 thread pool -// allow to do one time initialization. -// see discussion in https://github.com/pytorch/pytorch/pull/34845 -// Follow up issue: https://github.com/pytorch/pytorch/issues/35006 +// TODO: We shouldn't need to call this function because the engine +// can already persist the errors for us. This still seems to be +// needed for the DistEngine however. +// +// python test/distributed/rpc/test_tensorpipe_agent.py -k +// test_backward_autograd_engine_error +// +// See Note [ Persisting PyErr state across autograd engine threads ] void throw_python_error() { python_error err; err.persist();