Skip to content

Commit

Permalink
Capture the threading.Thread.name attribute
Browse files Browse the repository at this point in the history
Previously we only captured the thread name as set by `prctl`, but not
the thread name as returned by `threading.current_thread().name`.

Begin capturing the name for the Python thread as well. We retain only
the last name set for each thread, so assignments to `Thread.name`
override earlier calls to `prctl(PR_SET_NAME)`, and vice versa.

This implementation uses a custom descriptor to intercept assignments to
`Thread._name` and `Thread._ident` in order to detect when a thread has
a name or a thread id assigned to it. Because this is tricky and a bit
fragile (poking at the internals of `Thread`), I've implemented that
descriptor in a Python module. At least that way if it ever breaks, it
should be a bit easier for someone to investigate.

Signed-off-by: Matt Wozniski <[email protected]>
  • Loading branch information
godlygeek authored and pablogsal committed May 30, 2024
1 parent ceb29b5 commit 0943537
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 1 deletion.
1 change: 1 addition & 0 deletions news/562.feature.2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Capture the name attribute of Python ``threading.Thread`` objects.
13 changes: 12 additions & 1 deletion src/memray/_memray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ from ._destination import FileDestination
from ._destination import SocketDestination
from ._metadata import Metadata
from ._stats import Stats
from ._thread_name_interceptor import ThreadNameInterceptor


def set_log_level(int level):
Expand Down Expand Up @@ -691,7 +692,6 @@ cdef class Tracker:

@cython.profile(False)
def __enter__(self):

if NativeTracker.getTracker() != NULL:
raise RuntimeError("No more than one Tracker instance can be active at the same time")

Expand All @@ -700,6 +700,14 @@ cdef class Tracker:
raise RuntimeError("Attempting to use stale output handle")
writer = move(self._writer)

for attr in ("_name", "_ident"):
assert not hasattr(threading.Thread, attr)
setattr(
threading.Thread,
attr,
ThreadNameInterceptor(attr, NativeTracker.registerThreadNameById),
)

self._previous_profile_func = sys.getprofile()
self._previous_thread_profile_func = threading._profile_hook
threading.setprofile(start_thread_trace)
Expand All @@ -722,6 +730,9 @@ cdef class Tracker:
sys.setprofile(self._previous_profile_func)
threading.setprofile(self._previous_thread_profile_func)

for attr in ("_name", "_ident"):
delattr(threading.Thread, attr)


def start_thread_trace(frame, event, arg):
if event in {"call", "c_call"}:
Expand Down
27 changes: 27 additions & 0 deletions src/memray/_memray/tracking_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ Tracker::trackAllocationImpl(
hooks::Allocator func,
const std::optional<NativeTrace>& trace)
{
registerCachedThreadName();
PythonStackTracker::get().emitPendingPushesAndPops();

if (d_unwind_native_frames) {
Expand Down Expand Up @@ -871,6 +872,7 @@ Tracker::trackAllocationImpl(
void
Tracker::trackDeallocationImpl(void* ptr, size_t size, hooks::Allocator func)
{
registerCachedThreadName();
AllocationRecord record{reinterpret_cast<uintptr_t>(ptr), size, func};
if (!d_writer->writeThreadSpecificRecord(thread_id(), record)) {
std::cerr << "Failed to write output, deactivating tracking" << std::endl;
Expand Down Expand Up @@ -963,12 +965,37 @@ void
Tracker::registerThreadNameImpl(const char* name)
{
// Writes a ThreadRecord carrying `name` for the calling thread (keyed by
// thread_id()).
// NOTE(review): RecursionGuard presumably suppresses re-entrant tracking
// while the record is written -- confirm against its definition.
RecursionGuard guard;
// A name that was queued for this thread is superseded by the one being
// written now; discard it so registerCachedThreadName() cannot later write
// the stale value after this one.
dropCachedThreadName();
if (!d_writer->writeThreadSpecificRecord(thread_id(), ThreadRecord{name})) {
// A failed write is reported on stderr and tracking is deactivated.
std::cerr << "memray: Failed to write output, deactivating tracking" << std::endl;
deactivate();
}
}

void
Tracker::registerCachedThreadName()
{
    // Emit, then forget, a thread name that was queued for the calling thread
    // (looked up by its pthread_self() value).

    // Nothing queued for any thread: skip the lookup entirely.
    if (d_cached_thread_names.empty()) {
        return;
    }

    const auto pending = d_cached_thread_names.find((uint64_t)(pthread_self()));
    if (pending == d_cached_thread_names.end()) {
        // Names are queued, but none of them belongs to this thread.
        return;
    }

    const std::string& queued_name = pending->second;
    if (!d_writer->writeThreadSpecificRecord(thread_id(), ThreadRecord{queued_name.c_str()})) {
        std::cerr << "memray: Failed to write output, deactivating tracking" << std::endl;
        deactivate();
    }

    // The entry is removed whether or not the write succeeded.
    d_cached_thread_names.erase(pending);
}

void
Tracker::dropCachedThreadName()
{
    // Discard the name queued for the calling thread, if there is one;
    // erasing a key that is not present is a no-op.
    const auto self_key = (uint64_t)(pthread_self());
    d_cached_thread_names.erase(self_key);
}

frame_id_t
Tracker::registerFrame(const RawFrame& frame)
{
Expand Down
24 changes: 24 additions & 0 deletions src/memray/_memray/tracking_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,27 @@ class Tracker
}
}

// Record `name` as the name of the thread whose pthread_self() value, cast
// to uint64_t, equals `thread`.
//
// Does nothing when a RecursionGuard is already active, when tracking is
// not active, or when no tracker instance exists. Otherwise, under
// s_mutex:
//  - if `thread` is the calling thread, the name is written immediately
//    via registerThreadNameImpl();
//  - if not, the name is queued in d_cached_thread_names and written by
//    registerCachedThreadName() when that thread next runs it.
inline static void registerThreadNameById(uint64_t thread, const char* name)
{
    if (RecursionGuard::isActive || !Tracker::isActive()) {
        return;
    }
    RecursionGuard guard;

    std::unique_lock<std::mutex> lock(*s_mutex);
    Tracker* tracker = getTracker();
    if (tracker) {
        if (thread == (uint64_t)(pthread_self())) {
            tracker->registerThreadNameImpl(name);
        } else {
            // We've got a different thread's name, but don't know what id
            // has been assigned to that thread (if any!). Set this update
            // aside to be handled later, from that thread.
            //
            // insert_or_assign, not emplace: emplace leaves an existing
            // entry untouched, so a second rename arriving before the
            // target thread flushes its entry would be lost and the older
            // name reported. Only the most recent name must be retained.
            tracker->d_cached_thread_names.insert_or_assign(thread, name);
        }
    }
}

// RawFrame stack interface
bool pushFrame(const RawFrame& frame);
bool popFrames(uint32_t count);
Expand Down Expand Up @@ -359,6 +380,7 @@ class Tracker
const bool d_trace_python_allocators;
linker::SymbolPatcher d_patcher;
std::unique_ptr<BackgroundThread> d_background_thread;
std::unordered_map<uint64_t, std::string> d_cached_thread_names;

// Methods
static size_t computeMainTidSkip();
Expand All @@ -373,6 +395,8 @@ class Tracker
void invalidate_module_cache_impl();
void updateModuleCacheImpl();
void registerThreadNameImpl(const char* name);
void registerCachedThreadName();
void dropCachedThreadName();
void registerPymallocHooks() const noexcept;
void unregisterPymallocHooks() const noexcept;

Expand Down
4 changes: 4 additions & 0 deletions src/memray/_memray/tracking_api.pxd
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from _memray.record_writer cimport RecordWriter
from libc.stdint cimport uint64_t
from libcpp cimport bool
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
Expand Down Expand Up @@ -31,3 +32,6 @@ cdef extern from "tracking_api.h" namespace "memray::tracking_api":

@staticmethod
void handleGreenletSwitch(object, object) except+

@staticmethod
void registerThreadNameById(uint64_t, const char*) except+
23 changes: 23 additions & 0 deletions src/memray/_thread_name_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import threading
from typing import Callable


class ThreadNameInterceptor:
    """Record the name of each threading.Thread for Memray's reports.

    The name can be set either before or after the thread is started, and from
    either the same thread or a different thread. Whenever an assignment to
    either `Thread._name` or `Thread._ident` is performed and the other has
    already been set, we call a callback with the thread's ident and name.

    Instances are meant to be installed as class attributes; only `__set__`
    is defined, so reads of the attribute are served from the instance
    `__dict__` that `__set__` populates.
    """

    def __init__(self, attr: str, callback: Callable[[int, str], None]) -> None:
        """Remember which attribute is intercepted and whom to notify.

        Args:
            attr: Name of the instance attribute this descriptor stands in for.
            callback: Called as ``callback(ident, name)`` once both values
                are known for an instance.
        """
        self._attr = attr
        self._callback = callback

    def __set__(self, instance: threading.Thread, value: object) -> None:
        """Store `value` on `instance`, then report ident and name if both are set."""
        # Write straight into the instance dict: going through setattr()
        # would re-enter this descriptor.
        instance.__dict__[self._attr] = value
        ident = instance.__dict__.get("_ident")
        name = instance.__dict__.get("_name")
        # Either value may still be missing or None (the other attribute not
        # assigned yet); nothing is reported until both are available.
        if ident is not None and name is not None:
            self._callback(ident, name)
50 changes: 50 additions & 0 deletions tests/integration/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,53 @@ def allocating_function():
(valloc,) = vallocs
assert valloc.size == 1234
assert "my thread name" == valloc.thread_name


def test_setting_python_thread_name(tmpdir):
    """Each valloc is attributed to the most recently set thread name.

    The worker thread performs four valloc calls; between them its name is
    set before start, from inside the thread, from the main thread while the
    worker is running, and finally through `set_thread_name`.
    """
    # GIVEN
    output = Path(tmpdir) / "test.bin"
    allocator = MemoryAllocator()
    name_set_inside_thread = threading.Event()
    name_set_outside_thread = threading.Event()
    prctl_rc = -1

    def allocating_function():
        nonlocal prctl_rc

        allocator.valloc(1234)
        allocator.free()

        threading.current_thread().name = "set inside thread"
        allocator.valloc(1234)
        allocator.free()

        # Hand over to the main thread so it can rename this thread, and
        # wait until it has done so.
        name_set_inside_thread.set()
        name_set_outside_thread.wait()
        allocator.valloc(1234)
        allocator.free()

        prctl_rc = set_thread_name("set by prctl")
        allocator.valloc(1234)
        allocator.free()

    # WHEN
    with Tracker(output):
        t = threading.Thread(target=allocating_function, name="set before start")
        t.start()
        name_set_inside_thread.wait()
        t.name = "set outside running thread"
        name_set_outside_thread.set()
        t.join()

    # THEN
    expected_names = [
        "set before start",
        "set inside thread",
        "set outside running thread",
        # If set_thread_name did not return 0, the previous name is expected
        # to remain in effect for the last allocation.
        "set by prctl" if prctl_rc == 0 else "set outside running thread",
    ]
    names = [
        rec.thread_name
        for rec in FileReader(output).get_allocation_records()
        if rec.allocator == AllocatorType.VALLOC
    ]
    assert names == expected_names

0 comments on commit 0943537

Please sign in to comment.