diff --git a/src/memray/_memray.pyi b/src/memray/_memray.pyi index 09c818ab4b..e54736d63b 100644 --- a/src/memray/_memray.pyi +++ b/src/memray/_memray.pyi @@ -133,7 +133,10 @@ class FileReader: @property def metadata(self) -> Metadata: ... def __init__( - self, file_name: Union[str, Path], *, report_progress: bool = False + self, + file_name: Union[str, Path], + *, + report_progress: bool = False, ) -> None: ... def get_allocation_records(self) -> Iterable[AllocationRecord]: ... def get_temporal_allocation_records( diff --git a/src/memray/_memray.pyx b/src/memray/_memray.pyx index eb7c372fe4..b15cd40752 100644 --- a/src/memray/_memray.pyx +++ b/src/memray/_memray.pyx @@ -880,7 +880,7 @@ cdef class FileReader: n_memory_snapshots_approx = 2048 if 0 < stats["start_time"] < stats["end_time"]: n_memory_snapshots_approx = (stats["end_time"] - stats["start_time"]) / 10 - + if n_memory_snapshots_approx > max_memory_records: n_memory_snapshots_approx = max_memory_records self._memory_snapshots.reserve(n_memory_snapshots_approx) @@ -920,7 +920,7 @@ cdef class FileReader: self._memory_snapshots.push_back(reader.getLatestMemorySnapshot()) else: break - + if len(self._memory_snapshots) > max_memory_records: self._memory_snapshot_bucket = len(self._memory_snapshots) // max_memory_records self._memory_snapshots = self._memory_snapshots[::self._memory_snapshot_bucket] diff --git a/tests/integration/test_tracking.py b/tests/integration/test_tracking.py index b6f685ba0d..4bcb7920c4 100644 --- a/tests/integration/test_tracking.py +++ b/tests/integration/test_tracking.py @@ -1679,22 +1679,26 @@ def test_memory_snapshots_limit_when_reading(self, tmp_path): # WHEN with Tracker(output): - allocator.valloc(ALLOC_SIZE) - time.sleep(0.11) - allocator.free() + for _ in range(2): + allocator.valloc(ALLOC_SIZE) + time.sleep(0.11) + allocator.free() - memory_snapshots = list(FileReader(output).get_memory_snapshots()) + reader = FileReader(output) + memory_snapshots = list(reader.get_memory_snapshots()) + temporal_records = list(reader.get_temporal_allocation_records()) assert memory_snapshots - assert all(record.rss > 0 for record in memory_snapshots) - assert any(record.heap >= ALLOC_SIZE for record in memory_snapshots) - assert sorted(memory_snapshots, key=lambda r: r.time) == memory_snapshots - assert all( - _next.time - prev.time >= 10 - for prev, _next in zip(memory_snapshots, memory_snapshots[1:]) - ) + n_snapshots = len(memory_snapshots) + n_temporal_records = len(temporal_records) - memory_snapshots = list(FileReader(output).get_memory_snapshots()) + reader = FileReader(output, max_memory_records=n_snapshots // 2) + memory_snapshots = list(reader.get_memory_snapshots()) + temporal_records = list(reader.get_temporal_allocation_records()) + + assert memory_snapshots + assert len(memory_snapshots) <= n_snapshots // 2 + 1 + assert len(temporal_records) <= n_temporal_records // 2 + 1 def test_temporary_allocations_when_filling_vector_without_preallocating( self, tmp_path