diff --git a/news/667.bugfix.rst b/news/667.bugfix.rst new file mode 100644 index 0000000000..5ff4954621 --- /dev/null +++ b/news/667.bugfix.rst @@ -0,0 +1 @@ +Fix a race condition that was able to cause strange exception messages if two different threads tried to initialize Memray tracking at once. diff --git a/src/memray/_memray.pyx b/src/memray/_memray.pyx index 6e90677d5d..417f6f2657 100644 --- a/src/memray/_memray.pyx +++ b/src/memray/_memray.pyx @@ -581,6 +581,9 @@ cdef class ProfileFunctionGuard: NativeTracker.forgetPythonStack() +tracker_creation_lock = threading.Lock() + + cdef class Tracker: """Context manager for tracking memory allocations in a Python script. @@ -690,46 +693,48 @@ cdef class Tracker: @cython.profile(False) def __enter__(self): - if NativeTracker.getTracker() != NULL: - raise RuntimeError("No more than one Tracker instance can be active at the same time") - cdef unique_ptr[RecordWriter] writer - if self._writer == NULL: - raise RuntimeError("Attempting to use stale output handle") - writer = move(self._writer) - - for attr in ("_name", "_ident"): - assert not hasattr(threading.Thread, attr) - setattr( - threading.Thread, - attr, - ThreadNameInterceptor(attr, NativeTracker.registerThreadNameById), - ) + with tracker_creation_lock: + if NativeTracker.getTracker() != NULL: + raise RuntimeError("No more than one Tracker instance can be active at the same time") + + if self._writer == NULL: + raise RuntimeError("Attempting to use stale output handle") + writer = move(self._writer) + + for attr in ("_name", "_ident"): + assert not hasattr(threading.Thread, attr) + setattr( + threading.Thread, + attr, + ThreadNameInterceptor(attr, NativeTracker.registerThreadNameById), + ) - self._previous_profile_func = sys.getprofile() - self._previous_thread_profile_func = threading._profile_hook - threading.setprofile(start_thread_trace) + self._previous_profile_func = sys.getprofile() + self._previous_thread_profile_func = threading._profile_hook + threading.setprofile(start_thread_trace) - if "greenlet" in sys.modules: - NativeTracker.beginTrackingGreenlets() + if "greenlet" in sys.modules: + NativeTracker.beginTrackingGreenlets() - NativeTracker.createTracker( - move(writer), - self._native_traces, - self._memory_interval_ms, - self._follow_fork, - self._trace_python_allocators, - ) - return self + NativeTracker.createTracker( + move(writer), + self._native_traces, + self._memory_interval_ms, + self._follow_fork, + self._trace_python_allocators, + ) + return self @cython.profile(False) def __exit__(self, exc_type, exc_value, exc_traceback): - NativeTracker.destroyTracker() - sys.setprofile(self._previous_profile_func) - threading.setprofile(self._previous_thread_profile_func) + with tracker_creation_lock: + NativeTracker.destroyTracker() + sys.setprofile(self._previous_profile_func) + threading.setprofile(self._previous_thread_profile_func) - for attr in ("_name", "_ident"): - delattr(threading.Thread, attr) + for attr in ("_name", "_ident"): + delattr(threading.Thread, attr) def start_thread_trace(frame, event, arg):