-
Notifications
You must be signed in to change notification settings - Fork 207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes #718 WandB in OpenFL and example to change number of epochs per round #895
base: develop
Are you sure you want to change the base?
Conversation
57c3a3b
to
8f31b0b
Compare
Signed-off-by: Bruno Casella <[email protected]>
Signed-off-by: Bruno Casella <[email protected]>
e573b1a
to
a9a1266
Compare
Sorry @psfoley, there was a typo in the aggregator.py. I fixed it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @CasellaJr - thanks for this contribution. There's some changes that are required before this can be approved. Those changes fall into two categories:
wandb
cannot be assumed to be present in all federations. It should be made optional, and ideally this would be configurable in the FL Plan. TheLOG_WANDB
variable you defined I think aims to make it's use optional, but this is not something a user could set without modifyingaggregator.py
orcollaborator.py
directly in the current version. See theget_aggregator
function inplan.py
as an example of how the aggregator object gets constructed, and how the plan.yaml parameters are used to initialize the object.- Hard coded variables. The wandb project name and configuration should be made configurable (again through the plan).
@@ -18,6 +18,8 @@ | |||
from openfl.utilities import TensorKey | |||
from openfl.utilities.logs import write_metric | |||
|
|||
import wandb |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would require wandb
as an import for OpenFL, which is something we want to avoid.
@@ -18,6 +18,8 @@ | |||
from openfl.utilities import TensorKey | |||
from openfl.utilities.logs import write_metric | |||
|
|||
import wandb | |||
LOG_WANDB = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be configurable by the definer of the experiment, i.e. pass this through the OpenFL Plan
#The following variable, my_aggregator_name, can be changed to whatever you want. Right now I suppose that the name of the collaborators of the federation is in the format DATASETNAME_ENV_NUMBER | ||
my_aggregator_name = '_'.join(set(element.split('_')[0] for element in authorized_cols)) | ||
if LOG_WANDB: | ||
wandb.init(project="my_project", entity="my_group", group=f"{my_aggregator_name}", tags=["my_tag"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name of the project, entity, group, etc. should also be configurable through the plan
@@ -55,6 +57,17 @@ def __init__(self, | |||
log_metric_callback=None, | |||
**kwargs): | |||
"""Initialize.""" | |||
#INITIALIZE WANDB WITH CORRECT NAME | |||
#The following variable, my_aggregator_name, can be changed to whatever you want. Right now I suppose that the name of the collaborators of the federation is in the format DATASETNAME_ENV_NUMBER | |||
my_aggregator_name = '_'.join(set(element.split('_')[0] for element in authorized_cols)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be moved into the if LOG_WANDB
block because it's not used elsewhere. Also, it's a bit confusing that the my_aggregator_name
variable is a combined list of collaborators. May want to modify this to:
my_aggregator_name = '_'.join(set(element.split('_')[0] for element in authorized_cols)) | |
my_aggregator_name = 'Aggregator_' + '_'.join(set(element.split('_')[0] for element in authorized_cols)) |
"num_clients": 4, | ||
"rounds": 100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not familiar with the wandb configuration, but the same guidance applies here as earlier comments. I would expect num_clients
and rounds
to be configurable or based on existing parameters used elsewhere
import wandb | ||
LOG_WANDB = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment as above. Should only import this if wandb
is required in the plan
wandb.init(project="my_project", entity="my_group", tags=["my_tags"], | ||
config={ | ||
"num_clients": 4, | ||
"rounds": 100, | ||
}, | ||
name=self.collaborator_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment as above. Make project, entity, tags configurable. Use client count and rounds from existing plan variables.
With this PR it is possible to use WandB for tracking metrics of OpenFL experiments. For the Aggregator and each Collaborator (thanks to @Giemp95) will be created a new WandB run, that will be grouped under the name of the Aggregator for the sake of cleanliness.
Moreover, in the PyTorch_TinyImageNet.ipynb you can find how to track the metrics, and how to modify the number of epochs per round.