Fixes #718 WandB in OpenFL and example to change number of epochs per round #895

Open · wants to merge 2 commits into base: develop
@@ -43,6 +43,8 @@
 "from torch.utils.data import Dataset\n",
 "from torch.utils.data import DataLoader\n",
 "import tqdm\n",
+"import wandb\n",
+"LOG_WANDB = True\n",
 "\n",
 "torch.manual_seed(0)\n",
 "np.random.seed(0)"
@@ -384,16 +386,22 @@
 "    net_model.to(device)\n",
 "\n",
 "    losses = []\n",
+"    epochs = 1  # change this if you want more epochs per round\n",
 "\n",
-"    for data, target in train_loader:\n",
-"        data, target = torch.tensor(data).to(device), torch.tensor(\n",
-"            target).to(device)\n",
-"        optimizer.zero_grad()\n",
-"        output = net_model(data)\n",
-"        loss = loss_fn(output=output, target=target)\n",
-"        loss.backward()\n",
-"        optimizer.step()\n",
-"        losses.append(loss.detach().cpu().numpy())\n",
+"    for epoch in range(epochs):\n",
+"        for data, target in train_loader:\n",
+"            data, target = torch.tensor(data).to(device), torch.tensor(\n",
+"                target).to(device)\n",
+"            optimizer.zero_grad()\n",
+"            output = net_model(data)\n",
+"            loss = loss_fn(output=output, target=target)\n",
+"            loss.backward()\n",
+"            optimizer.step()\n",
+"            losses.append(loss.detach().cpu().numpy())\n",
+"\n",
+"    if LOG_WANDB:\n",
+"        wandb.run.summary['step'] = epoch\n",
+"        wandb.log({'Training loss': np.mean(losses)})\n",
 "    \n",
 "    return {'train_loss': np.mean(losses),}\n",
 "\n",
17 changes: 17 additions & 0 deletions openfl/component/aggregator/aggregator.py
@@ -18,6 +18,8 @@
 from openfl.utilities import TensorKey
 from openfl.utilities.logs import write_metric
 
+import wandb
Contributor:
This would require wandb as an import for OpenFL, which is something we want to avoid.
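A minimal sketch of one way to do that, assuming a module-level guard is acceptable (wandb stays optional; the WANDB_AVAILABLE name is illustrative, not part of this PR):

# Hypothetical guard: import wandb only if it is installed, so OpenFL itself
# does not acquire a hard dependency.
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    wandb = None
    WANDB_AVAILABLE = False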

+LOG_WANDB = True
Contributor:
This should be configurable by the definer of the experiment, i.e. pass this through the OpenFL Plan
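A rough sketch of what plan-level configuration could look like, assuming the plan's aggregator settings arrive as keyword arguments (the log_wandb name and the trimmed signature are illustrative):

class Aggregator:
    def __init__(self, aggregator_uuid, federation_uuid, authorized_cols,
                 log_wandb=False, **kwargs):
        # Hypothetical: take the flag from plan-provided settings instead of
        # a module-level constant.
        self.log_wandb = log_wandb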


 class Aggregator:
     r"""An Aggregator is the central node in federated learning.
@@ -55,6 +57,17 @@ def __init__(self,
                  log_metric_callback=None,
                  **kwargs):
         """Initialize."""
+        # Initialize WandB with the correct run name. my_aggregator_name can be
+        # renamed; this assumes collaborator names look like DATASETNAME_ENV_NUMBER.
+        my_aggregator_name = '_'.join(set(element.split('_')[0] for element in authorized_cols))
Contributor:
This should be moved into the if LOG_WANDB block because it's not used elsewhere. Also, it's a bit confusing that the my_aggregator_name variable is a combined list of collaborators. May want to modify this to:

Suggested change:
-        my_aggregator_name = '_'.join(set(element.split('_')[0] for element in authorized_cols))
+        my_aggregator_name = 'Aggregator_' + '_'.join(set(element.split('_')[0] for element in authorized_cols))

+        if LOG_WANDB:
+            wandb.init(project="my_project", entity="my_group", group=f"{my_aggregator_name}", tags=["my_tag"],
Contributor:
The name of the project, entity, group, etc. should also be configurable through the plan

+                       config={
+                           "num_clients": 4,
+                           "rounds": 100
Comment on lines +66 to +67
Contributor:
I'm not familiar with the wandb configuration, but the same guidance applies here as in the earlier comments. I would expect num_clients and rounds to be configurable or based on existing parameters used elsewhere.
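A combined sketch for this and the previous comment, assuming the plan forwards a wandb settings mapping through kwargs and that rounds_to_train and authorized_cols are in scope inside __init__ (the wandb_settings key names are illustrative):

# Hypothetical __init__ body: pull every WandB option from plan-provided
# settings and derive the config values from existing experiment parameters.
wandb_settings = kwargs.get('wandb_settings', {})
if wandb_settings.get('enabled', False):
    wandb.init(
        project=wandb_settings.get('project', 'openfl'),
        entity=wandb_settings.get('entity'),
        group=wandb_settings.get('group', my_aggregator_name),
        tags=wandb_settings.get('tags', []),
        config={
            'num_clients': len(authorized_cols),  # instead of a hardcoded 4
            'rounds': rounds_to_train,            # instead of a hardcoded 100
        },
        name=f'Aggregator_{my_aggregator_name}',
    )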

+                       },
+                       name=f"Aggregator_{my_aggregator_name}"
+                       )
         self.round_number = 0
         self.single_col_cert_common_name = single_col_cert_common_name

@@ -837,6 +850,8 @@ def _compute_validation_related_task_metrics(self, task_name):
         if agg_function:
             self.logger.metric(f'Round {round_number}, aggregator: {task_name} '
                                f'{agg_function} {agg_tensor_name}:\t{agg_results:f}')
+            if LOG_WANDB:
+                wandb.log({f"{task_name} {agg_tensor_name}": float(agg_results)}, step=round_number)
         else:
             self.logger.metric(f'Round {round_number}, aggregator: {task_name} '
                                f'{agg_tensor_name}:\t{agg_results:f}')
@@ -893,6 +908,8 @@ def _end_of_round_check(self):
         # TODO This needs to be fixed!
         if self._time_to_quit():
             self.logger.info('Experiment Completed. Cleaning up...')
+            if LOG_WANDB:
+                wandb.finish()
         else:
             self.logger.info(f'Starting round {self.round_number}...')

12 changes: 12 additions & 0 deletions openfl/component/collaborator/collaborator.py
@@ -14,6 +14,8 @@
 from openfl.protocols import utils
 from openfl.utilities import TensorKey
 
+import wandb
+LOG_WANDB = True
Comment on lines +17 to +18
Contributor:
Same comment as above. Should only import this if wandb is required in the plan


 class DevicePolicy(Enum):
     """Device assignment policy."""
@@ -133,6 +135,14 @@ def set_available_devices(self, cuda: Tuple[str] = ()):
 
     def run(self):
         """Run the collaborator."""
+        if LOG_WANDB:
+            wandb.init(project="my_project", entity="my_group", tags=["my_tags"],
+                       config={
+                           "num_clients": 4,
+                           "rounds": 100,
+                       },
+                       name=self.collaborator_name
Comment on lines +139 to +144
Contributor:
Same comment as above. Make project, entity, tags configurable. Use client count and rounds from existing plan variables.

+                       )
         while True:
             tasks, round_number, sleep_time, time_to_quit = self.get_tasks()
             if time_to_quit:
@@ -148,6 +158,8 @@ def run(self):
             self.tensor_db.clean_up(self.db_store_rounds)
 
         self.logger.info('End of Federation reached. Exiting...')
+        if LOG_WANDB:
+            wandb.finish()

     def run_simulation(self):
         """