Skip to content

Commit

Permalink
Docs update
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-hudson committed Mar 10, 2024
1 parent 0942f7c commit 6697605
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions docs/getting_started/strategies/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class MyStrategy(Strategy):
def agg_before_share_params(self):
pass

def agg_after_collect_params(self) -> 'state_dict':
def agg_after_collect_params(self) -> 'params':
pass

def wrk_before_submit_params(self) -> 'state_dict':
def wrk_before_submit_params(self) -> 'params':
pass

def wrk_on_recv_params(self):
Expand Down
28 changes: 14 additions & 14 deletions docs/getting_started/strategies/custom.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ class FedProx(FedAvg):
"""..."""

def __init__(
self,
mu: float = 0.3,
participation: float = 1.0,
probabilistic: bool = False,
always_include_child_aggregators: bool = True,
seed: int = None,
self,
mu: float = 0.3,
participation: float = 1.0,
probabilistic: bool = False,
always_include_child_aggregators: bool = True,
seed: int = None,
):
"""..."""
super().__init__(
Expand All @@ -32,17 +32,17 @@ class FedProx(FedAvg):
self.mu = mu

def wrk_after_train_step(
self,
state: FloxWorkerState,
loss: torch.Tensor,
**kwargs,
self,
state: FloxWorkerState,
loss: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""..."""
global_model = state.pre_local_train_model
local_model = state.post_local_train_model
global_model = state.global_model
local_model = state.local_model

params = list(local_model.state_dict().values())
params0 = list(global_model.state_dict().values())
params = list(local_model.params().values())
params0 = list(global_model.params().values())

norm = torch.sum(
torch.Tensor(
Expand Down

0 comments on commit 6697605

Please sign in to comment.