Skip to content

Commit

Permalink
feat[tui]: textual refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ayshiff committed Dec 14, 2022
1 parent f4f93a8 commit 2c0e749
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 60 deletions.
4 changes: 2 additions & 2 deletions src/memray/commands/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from memray import SocketReader
from memray._errors import MemrayCommandError
from memray.reporters.tui import TUI
from memray.reporters.tui import TUIApp

KEYS = {
"ESC": "\x1b",
Expand Down Expand Up @@ -90,4 +90,4 @@ def start_live_interface(self, port: int) -> None:
if port >= 2**16 or port <= 0:
raise MemrayCommandError(f"Invalid port: {port}", exit_code=1)
with SocketReader(port=port) as reader:
TUI(reader).run()
TUIApp(reader).run()
2 changes: 1 addition & 1 deletion src/memray/reporters/tui.css
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Screen {
TUI {
layout: grid;
grid-size: 1 6;
grid-rows: 2fr 7fr 3fr 70% 1fr 1fr;
Expand Down
158 changes: 103 additions & 55 deletions src/memray/reporters/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@
from rich.markup import escape
from rich.panel import Panel
from rich.progress_bar import ProgressBar

from textual.app import App, ComposeResult, Widget
from textual.widgets import DataTable, Footer, Static, Label
from textual.containers import Container
from textual.reactive import reactive, Reactive
from textual.app import App
from textual.app import ComposeResult
from textual.app import Screen
from textual.app import Widget
from textual.binding import Binding

from memray import AllocationRecord, SocketReader
from textual.containers import Container
from textual.reactive import reactive
from textual.widgets import DataTable
from textual.widgets import Footer
from textual.widgets import Label
from textual.widgets import Static

from memray import AllocationRecord
from memray import SocketReader
from memray._memray import size_fmt

MAX_MEMORY_RATIO = 0.95
Expand Down Expand Up @@ -217,8 +223,14 @@ class Table(Widget):
current_thread = reactive(0)
current_memory_size = reactive(0)

columns = ["Location", "Total Memory", "Total Memory %",
"Own Memory", "Own Memory % ", "Allocation Count"]
columns = [
"Location",
"Total Memory",
"Total Memory %",
"Own Memory",
"Own Memory % ",
"Allocation Count",
]

def __init__(self, native: bool):
super().__init__()
Expand Down Expand Up @@ -288,9 +300,7 @@ def render_table(self, snapshot) -> DataTable:
f"[bold magenta]{escape(location.function)}[/] at "
f"[cyan]{escape(location.file)}[/]"
)
total_color = _size_to_color(
result.total_memory / self.current_memory_size
)
total_color = _size_to_color(result.total_memory / self.current_memory_size)
own_color = _size_to_color(result.own_memory / self.current_memory_size)
allocation_colors = _size_to_color(result.n_allocations / total_allocations)
percent_total = result.total_memory / self.current_memory_size * 100
Expand Down Expand Up @@ -333,18 +343,18 @@ def compose(self) -> ComposeResult:
Label(f"[b]PID[/]: {self.pid}", id="pid"),
Label(id="tid"),
Label(id="samples"),
id="header_metadata_col_1"
id="header_metadata_col_1",
),
Container(
Label(f"[b]CMD[/]: {self.command_line}", id="cmd"),
Label(id="thread"),
Label(id="duration"),
id="header_metadata_col_2"
id="header_metadata_col_2",
),
id="header_metadata"
id="header_metadata",
),
Static(id="panel"),
id="header_container"
id="header_container",
)

def watch_n_samples(self, n_samples: int) -> None:
Expand All @@ -354,7 +364,8 @@ def watch_n_samples(self, n_samples: int) -> None:
def watch_last_update(self, last_update: datetime) -> None:
"""Called when the last_update attribute changes."""
self.query_one("#duration", Label).update(
f"[b]Duration[/]: {(last_update - self.start).total_seconds()} seconds")
f"[b]Duration[/]: {(last_update - self.start).total_seconds()} seconds"
)


class HeapSize(Widget):
Expand All @@ -371,28 +382,34 @@ def compose(self) -> ComposeResult:
)
yield Static(id="progress_bar")

def update_progress_bar(self, current_memory_size: int, max_memory_seen: int) -> None:
def update_progress_bar(
self, current_memory_size: int, max_memory_seen: int
) -> None:
"""Method to update the progress bar."""
self.query_one("#progress_bar", Static).update(ProgressBar(
completed=current_memory_size,
total=max_memory_seen + 1,
complete_style="blue",
))
self.query_one("#progress_bar", Static).update(
ProgressBar(
completed=current_memory_size,
total=max_memory_seen + 1,
complete_style="blue",
)
)

def watch_current_memory_size(self, current_memory_size: int) -> None:
"""Called when the current_memory_size attribute changes."""
self.query_one("#current_memory_size", Label).update(
f"[bold]Current heap size[/]: {size_fmt(current_memory_size)}")
f"[bold]Current heap size[/]: {size_fmt(current_memory_size)}"
)
self.update_progress_bar(current_memory_size, self.max_memory_seen)

def watch_max_memory_seen(self, max_memory_seen: int) -> None:
"""Called when the max_memory_seen attribute changes."""
self.query_one("#max_memory_seen", Label).update(
f"[bold]Max heap size seen[/]: {size_fmt(max_memory_seen)}")
f"[bold]Max heap size seen[/]: {size_fmt(max_memory_seen)}"
)
self.update_progress_bar(self.current_memory_size, max_memory_seen)


class TUI(App):
class TUI(Screen):
"""TUI main application class."""

KEY_TO_COLUMN_NAME = {
Expand Down Expand Up @@ -430,17 +447,13 @@ class TUI(App):
current_memory_size = reactive(0)
graph = reactive(stream.graph)

def __init__(self, reader: SocketReader):
self._reader = reader
def __init__(self, pid: Optional[int], cmd_line: Optional[str], native: bool):
self.pid, self.cmd_line, self.native = pid, cmd_line, native
self._seen_threads: Set[int] = set()
self._max_memory_seen = 0
self.active = True

super().__init__()

def on_mount(self):
self.auto_refresh = 0.1

@property
def current_thread(self) -> int:
return self.threads[self.thread_idx]
Expand All @@ -465,17 +478,18 @@ def action_sort(self, sort_attribute: str) -> None:

def watch_thread_idx(self, thread_idx: int) -> None:
"""Called when the thread_idx attribute changes."""
# Header
self.query_one("#tid", Label).update(f"[b]TID[/]: {hex(self.current_thread)}")
self.query_one("#thread", Label).update(
f"[b]Thread[/] {thread_idx + 1} of {len(self.threads)}")

f"[b]Thread[/] {thread_idx + 1} of {len(self.threads)}"
)
self.query_one(Table).current_thread = self.current_thread

def watch_threads(self, threads: List[int]) -> None:
"""Called when the threads attribute changes."""
self.query_one("#tid", Label).update(f"[b]TID[/]: {hex(self.current_thread)}")
self.query_one("#thread", Label).update(
f"[b]Thread[/] {self.thread_idx + 1} of {len(threads)}")
f"[b]Thread[/] {self.thread_idx + 1} of {len(threads)}"
)

def watch_current_memory_size(self, current_memory_size: int) -> None:
"""Called when the current_memory_size attribute changes."""
Expand All @@ -484,37 +498,36 @@ def watch_current_memory_size(self, current_memory_size: int) -> None:

def watch_graph(self, graph: List[Deque[str]]) -> None:
"""Called when the graph attribute changes to update the header panel."""
self.query_one("#panel", Static).update(Panel(
"\n".join(graph),
title="Memory",
title_align="left",
border_style="green",
expand=False,
))
self.query_one("#panel", Static).update(
Panel(
"\n".join(graph),
title="Memory",
title_align="left",
border_style="green",
expand=False,
)
)

def compose(self) -> ComposeResult:
yield Container(
Label("[b]Memray[/b] live tracking", id="head_title"),
TimeDisplay(id="head_time_display"),
id="head"
id="head",
)
yield Header(pid=self._reader.pid, cmd_line=escape(self._reader.command_line))
yield Header(pid=self.pid, cmd_line=escape(self.cmd_line))
yield HeapSize()
yield Table(native=self._reader.has_native_traces)
yield Table(native=self.native)
yield Label(id="message")
yield Footer()

def _automatic_refresh(self) -> None:
"""Method called every auto_refresh seconds."""
if self.active:
snapshot = list(self._reader.get_current_snapshot(merge_threads=False))
self.update_snapshot(snapshot)
def get_header(self):
return self.query_one("head", Container) + "\n" + self.query_one(Header)

if not self._reader.is_active:
self.active = False
self.query_one("#message", Label).update("[red]Remote has disconnected[/]")
def get_body(self):
return self.query_one(Table)

super()._automatic_refresh()
def get_heap_size(self):
return self.query_one(HeapSize)

def update_snapshot(self, snapshot: Iterable[AllocationRecord]) -> None:
"""Method called to update snapshot."""
Expand Down Expand Up @@ -555,3 +568,38 @@ def update_sort_key(self, col_number: int) -> None:
body = self.query_one(Table)
body.sort_column_id = col_number
body.sort_field_name = self.sort_field_name


class TUIApp(App):
"""TUI main application class."""

CSS_PATH = "tui.css"

def __init__(self, reader: SocketReader):
self._reader = reader
self.active = True
super().__init__()

def on_mount(self):
self.auto_refresh = 0.1
self.push_screen(
TUI(
pid=self._reader.pid,
cmd_line=self._reader.command_line,
native=self._reader.has_native_traces,
)
)

def _automatic_refresh(self) -> None:
"""Method called every auto_refresh seconds."""
if self.active:
snapshot = list(self._reader.get_current_snapshot(merge_threads=False))
self.query_one(TUI).update_snapshot(snapshot)

if not self._reader.is_active:
self.active = False
self.query_one("#message", Label).update(
"[red]Remote has disconnected[/]"
)

super()._automatic_refresh()
25 changes: 23 additions & 2 deletions tests/unit/test_tui_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from rich import print as rprint
from textual.app import App

from memray import AllocatorType
from memray.reporters.tui import TUI
Expand All @@ -20,8 +21,27 @@ def now(cls):
return datetime.datetime(2021, 1, 1)


class MockTUIApp(App):
def __init__(self, pid, cmd_line, native):
self.pid = pid
self.cmd_line = cmd_line
self.native = native
super().__init__()

def on_mount(self):
self.push_screen(
TUI(
pid=self.pid,
cmd_line=self.cmd_line,
native=self.native,
)
)


def make_tui(pid=123, cmd="python3 some_program.py", native=False):
return TUI(pid=pid, cmd_line=cmd, native=native)
tui_app = MockTUIApp(pid=pid, cmd_line=cmd, native=native)
tui_app.run()
return tui_app


@patch("memray.reporters.tui.datetime", FakeDate)
Expand All @@ -38,7 +58,8 @@ def test_pid(self, pid, out_str):
# GIVEN
snapshot = []
output = StringIO()
tui = make_tui(pid=pid, cmd="")
tui_app = make_tui(pid=pid, cmd="")
tui = tui_app.query_one(TUI)

# WHEN
tui.update_snapshot(snapshot)
Expand Down

0 comments on commit 2c0e749

Please sign in to comment.