diff --git a/openfl-tutorials/interactive_api/PyTorch_TinyImageNet/workspace/pytorch_tinyimagenet.ipynb b/openfl-tutorials/interactive_api/PyTorch_TinyImageNet/workspace/pytorch_tinyimagenet.ipynb index bbfadfd0860..09d1511c698 100644 --- a/openfl-tutorials/interactive_api/PyTorch_TinyImageNet/workspace/pytorch_tinyimagenet.ipynb +++ b/openfl-tutorials/interactive_api/PyTorch_TinyImageNet/workspace/pytorch_tinyimagenet.ipynb @@ -43,6 +43,8 @@ "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "import tqdm\n", + "import wandb\n", + "LOG_WANDB = True\n", "\n", "torch.manual_seed(0)\n", "np.random.seed(0)" @@ -384,16 +386,22 @@ " net_model.to(device)\n", "\n", " losses = []\n", + " epochs = 1 #change this if you want more epochs per round\n", "\n", - " for data, target in train_loader:\n", - " data, target = torch.tensor(data).to(device), torch.tensor(\n", - " target).to(device)\n", - " optimizer.zero_grad()\n", - " output = net_model(data)\n", - " loss = loss_fn(output=output, target=target)\n", - " loss.backward()\n", - " optimizer.step()\n", - " losses.append(loss.detach().cpu().numpy())\n", + " for epoch in range(epochs):\n", + " for data, target in train_loader:\n", + " data, target = torch.tensor(data).to(device), torch.tensor(\n", + " target).to(device)\n", + " optimizer.zero_grad()\n", + " output = net_model(data)\n", + " loss = loss_fn(output=output, target=target)\n", + " loss.backward()\n", + " optimizer.step()\n", + " losses.append(loss.detach().cpu().numpy())\n", + + " if LOG_WANDB:\n", + "wandb.run.summary['step'] = epoch\n", + "wandb.log({'Training loss': np.mean(losses)})\n", " \n", " return {'train_loss': np.mean(losses),}\n", "\n", diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index f2d3c17b5fd..b864c720d0a 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -18,6 +18,8 @@ from openfl.utilities import TensorKey from openfl.utilities.logs import write_metric +import wandb +LOG_WANDB = True class Aggregator: r"""An Aggregator is the central node in federated learning. @@ -55,6 +57,18 @@ def __init__(self, log_metric_callback=None, **kwargs): """Initialize.""" + #INITIALIZE WANDB WITH CORRECT NAME + #The following variable, my_aggregator_name, can be changed to whatever you want. Right now I suppose that the name of the collaborators of the federation is in the format +DATASETNAME_ENV_NUMBER + my_aggregator_name = '_'.join(set(element.split('_')[0] for element in authorized_cols)) + if LOG_WANDB: + wandb.init(project="my_project", entity="my_group", group=f"{my_aggregator_name}", tags=["my_tag"], + config={ + "num_clients": 4, + "rounds": 100 + }, + name=f"Aggregator_{my_aggregator_name}" +) self.round_number = 0 self.single_col_cert_common_name = single_col_cert_common_name @@ -837,6 +851,8 @@ def _compute_validation_related_task_metrics(self, task_name): if agg_function: self.logger.metric(f'Round {round_number}, aggregator: {task_name} ' f'{agg_function} {agg_tensor_name}:\t{agg_results:f}') + if LOG_WANDB: + wandb.log({f"{task_name} {agg_tensor_name}": float(f"{agg_results}")}, step=round_number) else: self.logger.metric(f'Round {round_number}, aggregator: {task_name} ' f'{agg_tensor_name}:\t{agg_results:f}') @@ -893,6 +909,8 @@ def _end_of_round_check(self): # TODO This needs to be fixed! if self._time_to_quit(): self.logger.info('Experiment Completed. Cleaning up...') + if LOG_WANDB: + wandb.finish() else: self.logger.info(f'Starting round {self.round_number}...') diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index 59193983880..eb346f15ce1 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -14,6 +14,8 @@ from openfl.protocols import utils from openfl.utilities import TensorKey +import wandb +LOG_WANDB = True class DevicePolicy(Enum): """Device assignment policy.""" @@ -133,6 +135,14 @@ def set_available_devices(self, cuda: Tuple[str] = ()): def run(self): """Run the collaborator.""" + if LOG_WANDB: + wandb.init(project="my_project", entity="my_group", tags=["my_tags"], + config={ + "num_clients": 4, + "rounds": 100, + }, + name=self.collaborator_name +) while True: tasks, round_number, sleep_time, time_to_quit = self.get_tasks() if time_to_quit: @@ -148,6 +158,8 @@ def run(self): self.tensor_db.clean_up(self.db_store_rounds) self.logger.info('End of Federation reached. Exiting...') + if LOG_WANDB: + wandb.finish() def run_simulation(self): """