Skip to content

Commit

Permalink
WandB in OpenFL
Browse files Browse the repository at this point in the history
  • Loading branch information
CasellaJr committed Nov 8, 2023
1 parent 79b8dbc commit 57c3a3b
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"import tqdm\n",
"import wandb\n",
"LOG_WANDB = True\n",
"\n",
"torch.manual_seed(0)\n",
"np.random.seed(0)"
Expand Down Expand Up @@ -384,16 +386,22 @@
" net_model.to(device)\n",
"\n",
" losses = []\n",
" epochs = 1 #change this if you want more epochs per round\n",
"\n",
" for data, target in train_loader:\n",
" data, target = torch.tensor(data).to(device), torch.tensor(\n",
" target).to(device)\n",
" optimizer.zero_grad()\n",
" output = net_model(data)\n",
" loss = loss_fn(output=output, target=target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" losses.append(loss.detach().cpu().numpy())\n",
" for epoch in range(epochs):\n",
" for data, target in train_loader:\n",
" data, target = torch.tensor(data).to(device), torch.tensor(\n",
" target).to(device)\n",
" optimizer.zero_grad()\n",
" output = net_model(data)\n",
" loss = loss_fn(output=output, target=target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" losses.append(loss.detach().cpu().numpy())\n",

" if LOG_WANDB:\n",
"wandb.run.summary['step'] = epoch\n",
"wandb.log({'Training loss': np.mean(losses)})\n",
" \n",
" return {'train_loss': np.mean(losses),}\n",
"\n",
Expand Down
18 changes: 18 additions & 0 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from openfl.utilities import TensorKey
from openfl.utilities.logs import write_metric

import wandb
LOG_WANDB = True

class Aggregator:
r"""An Aggregator is the central node in federated learning.
Expand Down Expand Up @@ -55,6 +57,18 @@ def __init__(self,
log_metric_callback=None,
**kwargs):
"""Initialize."""
#INITIALIZE WANDB WITH CORRECT NAME
#The following variable, my_aggregator_name, can be changed to whatever you want. Right now I suppose that the name of the collaborators of the federation is in the format
DATASETNAME_ENV_NUMBER
my_aggregator_name = '_'.join(set(element.split('_')[0] for element in authorized_cols))
if LOG_WANDB:
wandb.init(project="my_project", entity="my_group", group=f"{my_aggregator_name}", tags=["my_tag"],
config={
"num_clients": 4,
"rounds": 100
},
name=f"Aggregator_{my_aggregator_name}"
)
self.round_number = 0
self.single_col_cert_common_name = single_col_cert_common_name

Expand Down Expand Up @@ -837,6 +851,8 @@ def _compute_validation_related_task_metrics(self, task_name):
if agg_function:
self.logger.metric(f'Round {round_number}, aggregator: {task_name} '
f'{agg_function} {agg_tensor_name}:\t{agg_results:f}')
if LOG_WANDB:
wandb.log({f"{task_name} {agg_tensor_name}": float(f"{agg_results}")}, step=round_number)
else:
self.logger.metric(f'Round {round_number}, aggregator: {task_name} '
f'{agg_tensor_name}:\t{agg_results:f}')
Expand Down Expand Up @@ -893,6 +909,8 @@ def _end_of_round_check(self):
# TODO This needs to be fixed!
if self._time_to_quit():
self.logger.info('Experiment Completed. Cleaning up...')
if LOG_WANDB:
wandb.finish()
else:
self.logger.info(f'Starting round {self.round_number}...')

Expand Down
12 changes: 12 additions & 0 deletions openfl/component/collaborator/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from openfl.protocols import utils
from openfl.utilities import TensorKey

import wandb
LOG_WANDB = True

class DevicePolicy(Enum):
"""Device assignment policy."""
Expand Down Expand Up @@ -133,6 +135,14 @@ def set_available_devices(self, cuda: Tuple[str] = ()):

def run(self):
"""Run the collaborator."""
if LOG_WANDB:
wandb.init(project="my_project", entity="my_group", tags=["my_tags"],
config={
"num_clients": 4,
"rounds": 100,
},
name=self.collaborator_name
)
while True:
tasks, round_number, sleep_time, time_to_quit = self.get_tasks()
if time_to_quit:
Expand All @@ -148,6 +158,8 @@ def run(self):
self.tensor_db.clean_up(self.db_store_rounds)

self.logger.info('End of Federation reached. Exiting...')
if LOG_WANDB:
wandb.finish()

def run_simulation(self):
"""
Expand Down

0 comments on commit 57c3a3b

Please sign in to comment.