From 624253a599ef4e76d2e663feb72dc62b6927cbd7 Mon Sep 17 00:00:00 2001 From: "Chaurasiya, Payal" Date: Tue, 19 Nov 2024 20:58:06 -0800 Subject: [PATCH] Put memory under flag Signed-off-by: Chaurasiya, Payal --- openfl/component/aggregator/aggregator.py | 9 ++++++--- openfl/component/collaborator/collaborator.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index ff221fcfe0..c36852c3a9 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -128,7 +128,7 @@ def __init__( ) self._end_of_round_check_done = [False] * rounds_to_train self.stragglers = [] - self.log_memory_usage = log_memory_usage + self.log_memory_usage = log_memory_usage # Flag can be enabled to get memory usage details for ubuntu system self.memory_details = [] self.rounds_to_train = rounds_to_train @@ -1019,8 +1019,11 @@ def _end_of_round_check(self): all_tasks = self.assigner.get_all_tasks_for_round(self.round_number) for task_name in all_tasks: self._compute_validation_related_task_metrics(task_name) - memory_detail = self.get_memory_usage(self.round_number, "aggregator") - self.memory_details.append(memory_detail) + + if self.log_memory_usage: + # This is the place to check the memory usage of the aggregator + memory_detail = self.get_memory_usage(self.round_number, "aggregator") + self.memory_details.append(memory_detail) # Once all of the task results have been processed self._end_of_round_check_done[self.round_number] = True diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index 2b851ba7c5..e0925bc063 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -126,7 +126,7 @@ def __init__( self.delta_updates = delta_updates self.client = client - self.log_memory_usage = log_memory_usage + self.log_memory_usage = log_memory_usage # Flag can be enabled to get memory usage details for ubuntu system self.task_config = task_config self.logger = getLogger(__name__)