Skip to content

Commit

Permalink
Put thread name formatting in a free function
Browse files Browse the repository at this point in the history
Introduce a new `memray.reporters.common` module for utilities needed by
multiple reporters. Replace the `AllocationRecord.pretty_thread_name`
property with a `format_thread_name` function in that new module.

Signed-off-by: Matt Wozniski <[email protected]>
  • Loading branch information
godlygeek authored and pablogsal committed May 30, 2024
1 parent 0943537 commit 896b37d
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 28 deletions.
4 changes: 0 additions & 4 deletions src/memray/_memray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class AllocationRecord:
def native_segment_generation(self) -> int: ...
@property
def thread_name(self) -> str: ...
@property
def pretty_thread_name(self) -> str: ...
def hybrid_stack_trace(
self,
max_stacks: Optional[int] = None,
Expand Down Expand Up @@ -94,8 +92,6 @@ class TemporalAllocationRecord:
def native_segment_generation(self) -> int: ...
@property
def thread_name(self) -> str: ...
@property
def pretty_thread_name(self) -> str: ...
def hybrid_stack_trace(
self,
max_stacks: Optional[int] = None,
Expand Down
14 changes: 0 additions & 14 deletions src/memray/_memray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -309,14 +309,6 @@ cdef class AllocationRecord:
assert self._reader.get() != NULL, "Cannot get thread name without reader."
return self._reader.get().getThreadName(self.tid)

@property
def pretty_thread_name(self):
if self.tid == -1:
return "merged thread"
name = self.thread_name
thread_id = hex(self.tid)
return f"{thread_id} ({name})" if name else f"{thread_id}"

def stack_trace(self, max_stacks=None):
cache_key = ("python", max_stacks)
if cache_key not in self._stack_trace_cache:
Expand Down Expand Up @@ -450,12 +442,6 @@ cdef class TemporalAllocationRecord:
assert self._reader.get() != NULL, "Cannot get thread name without reader."
return self._reader.get().getThreadName(self.tid)

@property
def pretty_thread_name(self):
name = self.thread_name
thread_id = hex(self.tid)
return f"{thread_id} ({name})" if name else f"{thread_id}"

def stack_trace(self, max_stacks=None):
cache_key = ("python", max_stacks)
if cache_key not in self._stack_trace_cache:
Expand Down
14 changes: 14 additions & 0 deletions src/memray/reporters/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Union

from memray._memray import AllocationRecord
from memray._memray import TemporalAllocationRecord


def format_thread_name(
record: Union[AllocationRecord, TemporalAllocationRecord]
) -> str:
if record.tid == -1:
return "merged thread"
name = record.thread_name
thread_id = hex(record.tid)
return f"{thread_id} ({name})" if name else f"{thread_id}"
7 changes: 4 additions & 3 deletions src/memray/reporters/flamegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from memray import Metadata
from memray._memray import Interval
from memray._memray import TemporalAllocationRecord
from memray.reporters.common import format_thread_name
from memray.reporters.frame_tools import StackFrame
from memray.reporters.frame_tools import is_cpython_internal
from memray.reporters.frame_tools import is_frame_from_import_system
Expand Down Expand Up @@ -263,21 +264,21 @@ def _from_any_snapshot(

unique_threads: Set[str] = set()
for record in allocations:
unique_threads.add(record.pretty_thread_name)
unique_threads.add(format_thread_name(record))

record_data: RecordData
if temporal:
assert isinstance(record, TemporalAllocationRecord)
record_data = {
"thread_name": record.pretty_thread_name,
"thread_name": format_thread_name(record),
"intervals": record.intervals,
"size": None,
"n_allocations": None,
}
else:
assert not isinstance(record, TemporalAllocationRecord)
record_data = {
"thread_name": record.pretty_thread_name,
"thread_name": format_thread_name(record),
"intervals": None,
"size": record.size,
"n_allocations": record.n_allocations,
Expand Down
3 changes: 2 additions & 1 deletion src/memray/reporters/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from memray import AllocatorType
from memray import MemorySnapshot
from memray import Metadata
from memray.reporters.common import format_thread_name
from memray.reporters.templates import render_report


Expand Down Expand Up @@ -47,7 +48,7 @@ def from_snapshot(
allocator = AllocatorType(record.allocator)
result.append(
{
"tid": record.pretty_thread_name,
"tid": format_thread_name(record),
"size": record.size,
"allocator": allocator.name.lower(),
"n_allocations": record.n_allocations,
Expand Down
3 changes: 2 additions & 1 deletion src/memray/reporters/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from memray import AllocatorType
from memray import MemorySnapshot
from memray import Metadata
from memray.reporters.common import format_thread_name

Location = Tuple[str, str]

Expand Down Expand Up @@ -117,7 +118,7 @@ def render_as_csv(
record.n_allocations,
record.size,
record.tid,
record.pretty_thread_name,
format_thread_name(record),
"|".join(f"{func};{mod};{line}" for func, mod, line in stack_trace),
]
)
3 changes: 2 additions & 1 deletion src/memray/reporters/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from memray.reporters._textual_hacks import Bindings
from memray.reporters._textual_hacks import redraw_footer
from memray.reporters._textual_hacks import update_key_description
from memray.reporters.common import format_thread_name
from memray.reporters.frame_tools import is_cpython_internal
from memray.reporters.frame_tools import is_frame_from_import_system
from memray.reporters.frame_tools import is_frame_interesting
Expand Down Expand Up @@ -476,7 +477,7 @@ def from_snapshot(
current_frame = current_frame.children[stack_frame]
current_frame.value += size
current_frame.n_allocations += record.n_allocations
current_frame.thread_id = record.pretty_thread_name
current_frame.thread_id = format_thread_name(record)

if index > MAX_STACKS:
break
Expand Down
4 changes: 0 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ class MockAllocationRecord:
_hybrid_stack: Optional[List[Tuple[str, str, int]]] = None
thread_name: str = ""

@property
def pretty_thread_name(self):
return str(hex(self.tid)) if self.tid != -1 else "merged thread"

@staticmethod
def __get_stack_trace(stack, max_stacks):
if max_stacks == 0:
Expand Down

0 comments on commit 896b37d

Please sign in to comment.