Skip to content

Commit

Permalink
Protect against deregistered profile functions in greenlet switches
Browse files Browse the repository at this point in the history
When greenlet tracking is enabled it's possible that we run into a
situation where the function that recreates the Python stack in our TLS
variable after a greenlet switch is called **after** the profile
function has been deactivated. In this case, recreating the Python stack
is wrong as we are no longer tracking POP/PUSH events so when the stack
is inspected later nothing guarantees that the frames are still valid.

Signed-off-by: Pablo Galindo <[email protected]>
  • Loading branch information
pablogsal committed Dec 1, 2024
1 parent aa1b452 commit 1a38fba
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 3 deletions.
1 change: 1 addition & 0 deletions news/700.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a crash when a greenlet switch happens after Memray's profile function has been deactivated or replaced.
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Cython
coverage[toml]
greenlet; python_version < '3.13'
greenlet; python_version < '3.14'
pytest
pytest-cov
ipython
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def build_js_files(self):

test_requires = [
"Cython",
"greenlet; python_version < '3.13'",
"greenlet; python_version < '3.14'",
"pytest",
"pytest-cov",
"ipython",
Expand Down
8 changes: 8 additions & 0 deletions src/memray/_memray/tracking_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,9 +1216,17 @@ Tracker::beginTrackingGreenlets()
void
Tracker::handleGreenletSwitch(PyObject* from, PyObject* to)
{
// We must stop tracking the stack once our trace function is uninstalled.
// Otherwise, we'd keep referencing frames after they're destroyed.
PyThreadState* ts = PyThreadState_Get();
if (ts->c_profilefunc != PyTraceFunction) {
return;
}

// Grab the Tracker lock, as this may need to write pushes/pops.
std::unique_lock<std::mutex> lock(*s_mutex);
RecursionGuard guard;

PythonStackTracker::get().handleGreenletSwitch(from, to);
}

Expand Down
72 changes: 71 additions & 1 deletion tests/integration/test_greenlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tests.utils import filter_relevant_allocations

pytestmark = pytest.mark.skipif(
sys.version_info >= (3, 12), reason="Greenlet does not yet support Python 3.12"
sys.version_info >= (3, 14), reason="Greenlet does not yet support Python 3.14"
)


Expand Down Expand Up @@ -194,3 +194,73 @@ def stack(alloc):
assert vallocs[0].tid != vallocs[1].tid != vallocs[6].tid
assert vallocs[0].tid == vallocs[2].tid
assert vallocs[1].tid == vallocs[3].tid == vallocs[4].tid == vallocs[5].tid


def test_uninstall_profile_in_greenlet(tmpdir):
"""Verify that memray handles profile function changes in greenlets correctly."""
# GIVEN
output = Path(tmpdir) / "test.bin"
subprocess_code = textwrap.dedent(
f"""
import greenlet
import sys
from memray import Tracker
from memray._test import MemoryAllocator
def foo():
bar()
allocator.valloc(1024 * 10)
def bar():
baz()
def baz():
sys.setprofile(None)
other.switch()
def test():
allocator.valloc(1024 * 70)
main_greenlet.switch()
allocator = MemoryAllocator()
output = "{output}"
with Tracker(output):
main_greenlet = greenlet.getcurrent()
other = greenlet.greenlet(test)
foo()
"""
)

# WHEN
subprocess.run([sys.executable, "-Xdev", "-c", subprocess_code], timeout=5)

# THEN
reader = FileReader(output)
records = list(reader.get_allocation_records())
vallocs = [
record
for record in filter_relevant_allocations(records)
if record.allocator == AllocatorType.VALLOC
]

def stack(alloc):
return [frame[0] for frame in alloc.stack_trace()]

# Verify allocations and their stack traces (which should be empty
# because we remove the tracking function)
assert len(vallocs) == 2

assert stack(vallocs[0]) == []
assert vallocs[0].size == 70 * 1024

assert stack(vallocs[1]) == []
assert vallocs[1].size == 10 * 1024

# Verify thread IDs
main_tid = vallocs[0].tid # inner greenlet
outer_tid = vallocs[1].tid # outer greenlet
assert main_tid == outer_tid

0 comments on commit 1a38fba

Please sign in to comment.