Fix autograd engine callback error propagation from device thread (pytorch#113702)

The existing try-catch doesn't work because it doesn't call err.persist(). This is in contrast to the try-catch around evaluate_function, which does work because it calls into python_engine's thread_on_exception, which in turn calls persist.

Calling persist on a python_error stashes the PyErr state from the thread-local PyThreadState onto the python_error object, so that when this error object is stored on the future and passed back to the calling CPU thread, python_engine's execute try-catch can err.restore() the error state. Finally, python_engine's execute re-raises so that the error is re-caught by the HANDLE_TH_ERRORS macro.
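To make the hand-off concrete, here is a minimal, self-contained C++ sketch of the pattern described above. It is not the engine's actual code: `persisted_error` and the `thread_local` string `g_thread_err` are illustrative stand-ins for `python_error` and the per-thread CPython error indicator (`PyThreadState`), and `std::promise`/`std::future` stand in for the graph task's future.

```cpp
// Minimal sketch (not PyTorch code): persist per-thread error state onto an
// exception object on a worker thread, ship it through a future, and restore
// it on the calling thread.
#include <future>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>

// Stand-in for the per-thread CPython error indicator (PyThreadState).
thread_local std::string g_thread_err;

// Stand-in for python_error with persist()/restore().
struct persisted_error : std::runtime_error {
  explicit persisted_error(std::string msg)
      : std::runtime_error(msg), persisted(std::move(msg)) {}
  // Like restore(): copy the stashed state back into the *current* thread's
  // error indicator.
  void restore() const { g_thread_err = persisted; }
  std::string persisted;  // like persist(): state stashed on the object itself
};

int main() {
  std::promise<void> p;
  auto fut = p.get_future();  // stand-in for the graph task's future

  std::thread worker([&p] {
    g_thread_err = "blah";  // the queued callback "raised" on the device thread
    // persist(): capture the worker thread's error state on the exception
    // object, then stash the exception on the future for the caller.
    p.set_exception(std::make_exception_ptr(persisted_error(g_thread_err)));
  });

  try {
    fut.get();  // the calling CPU thread waits and rethrows the stored error
  } catch (const persisted_error& e) {
    e.restore();  // replay the error state on *this* thread before re-raising
    std::cout << "restored on calling thread: " << g_thread_err << "\n";
  }
  worker.join();
  return 0;
}
```

As in the actual fix, the key step is copying the thread-local error state onto the exception object before it leaves the worker thread; restore() then replays it on whichever thread ultimately observes the future.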

Fixes pytorch#75750

Pull Request resolved: pytorch#113702
Approved by: https://github.com/albanD
soulitzer authored and pytorchmergebot committed Nov 17, 2023
1 parent 957312a commit c435b8c
Showing 5 changed files with 64 additions and 12 deletions.
1 change: 1 addition & 0 deletions test/inductor/test_compiled_autograd.py
@@ -535,6 +535,7 @@ def wrapped(self: EagerAutogradTests):
"test_unrelated_inputs", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
"test_will_engine_execute_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd
"test_backward_to_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd
"test_callback_propagates_errors_from_device_thread", # AssertionError: "blah" does not match "call_method UserDefinedObj..."
}

if not HAS_CUDA:
34 changes: 34 additions & 0 deletions test/test_autograd.py
@@ -6336,6 +6336,22 @@ def backward(ctx, grad):

self.assertEqual(called[0], 2)

@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_callback_propagates_errors_from_device_thread(self):
    def callback():
        raise RuntimeError("blah")

    def hook_with_callback(*args):
        torch.autograd.Variable._execution_engine.queue_callback(callback)

    t = torch.tensor([1., 2.], requires_grad=True, device=torch.device("cuda"))
    t.register_hook(hook_with_callback)
    output = t ** 2
    loss = output.sum()

    with self.assertRaisesRegex(RuntimeError, "blah"):
        loss.backward()

def _test_reentrant_with_callbacks(self, install_callbacks_in_depths):
counter = {}
counter["inner"] = 0
@@ -11281,6 +11297,24 @@ def test_set_multithreading_enabled_as_context_manager_and_function(self):
torch.autograd.set_multithreading_enabled(True)
self.assertTrue(torch.autograd.is_multithreading_enabled())

@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_custom_function_propagates_errors_from_device_thread(self):
    class MyFunc(Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, gO):
            raise RuntimeError("blah")
            return gO

    t = torch.tensor([1., 2.], requires_grad=True, device=torch.device("cuda"))
    out = MyFunc.apply(t).sum()

    with self.assertRaisesRegex(RuntimeError, "blah"):
        out.backward()

class TestNestedCheckpoint(TestCase):
@staticmethod
def grad(fn):
1 change: 1 addition & 0 deletions torch/csrc/autograd/engine.cpp
@@ -575,6 +575,7 @@ auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
local_graph_task->cpu_ready_queue_);
}
} catch (std::exception& e) {
// See Note [ Persisting PyErr state across autograd engine threads ]
thread_on_exception(local_graph_task, task.fn_, e);
}
}
22 changes: 20 additions & 2 deletions torch/csrc/autograd/python_engine.cpp
@@ -106,6 +106,7 @@ void PythonEngine::thread_on_exception(
std::shared_ptr<GraphTask> graph_task,
const std::shared_ptr<Node>& fn,
std::exception& e) {
// See Note [ Persisting PyErr state across autograd engine threads ]
auto python_err = dynamic_cast<python_error*>(&e);
if (python_err) {
python_err->persist();
@@ -392,8 +393,25 @@ PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) {
engine.queue_callback([callback]() {
pybind11::gil_scoped_acquire gil;
THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
if (!result)
throw python_error();
if (!result) {
// Note [ Persisting PyErr state across autograd engine threads ]
//
// Since the autograd engine is multi-threaded, and Python error state is
// local to each thread, it must preserve the python error from the worker
// thread and rethrow it as-is in the calling thread. This is done via
// persisting the error in the two places that can encounter Python
// errors: (1) evaluate function and (2) queued callbacks.
//
// TODO: the engine is not actually responsible for persisting the error
// in the custom autograd Function case today! See the note above
// `raise_python_error()` function in python_function.cpp and
// python_hooks.cpp for more details. Persisting an extra time in the
// engine is fine because doing so is a no-op when the python_error has
// already been persisted.
python_error err;
err.persist();
throw std::move(err);
}
});
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
18 changes: 8 additions & 10 deletions torch/csrc/autograd/python_function.cpp
@@ -56,16 +56,14 @@ PyObject* THPGradientEdgeClass = nullptr;
// Anonymous namespace for helpful functions used in this file
namespace {

// Throw a python_error with the PyErr state persisted, so that we
// don't lose the error state if the GIL is released when we don't
// have a PyThreadState created beforehand, this is made so that
// even for pure C++ thread without a pre-created PyThreadState could
// also capture the correct error message.
// TODO: This is a temporary approach to allow C++ thread to correctly
// capture Python Error in autograd, remove this when c10 thread pool
// allow to do one time initialization.
// see discussion in https://github.com/pytorch/pytorch/pull/34845
// Follow up issue: https://github.com/pytorch/pytorch/issues/35006
// TODO: We shouldn't need to call this function because the engine
// can already persist the errors for us. This still seems to be
// needed for the DistEngine however.
//
// python test/distributed/rpc/test_tensorpipe_agent.py -k
// test_backward_autograd_engine_error
//
// See Note [ Persisting PyErr state across autograd engine threads ]
void throw_python_error() {
python_error err;
err.persist();
