-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #708 from epfml/694-decentralized-fail-julien
Fix decentralized learning fail
- Loading branch information
Showing
35 changed files
with
491 additions
and
461 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,59 @@ | ||
import type { Task } from '../index.js' | ||
import { aggregator } from '../index.js' | ||
import type { Model } from "../index.js"; | ||
|
||
/** | ||
* Enumeration of the available types of aggregator. | ||
*/ | ||
export enum AggregatorChoice { | ||
MEAN, | ||
SECURE, | ||
BANDIT | ||
} | ||
type AggregatorOptions = Partial<{ | ||
model: Model, | ||
scheme: Task['trainingInformation']['scheme'], // if undefined, fallback on task.trainingInformation.scheme | ||
roundCutOff: number, // MeanAggregator | ||
threshold: number, // MeanAggregator | ||
thresholdType: 'relative' | 'absolute', // MeanAggregator | ||
}> | ||
|
||
/** | ||
* Provides the aggregator object adequate to the given task. | ||
* @param task The task | ||
* Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters. | ||
* Here is the ordered list of parameters used to define the aggregator and its default behavior: | ||
* task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme | ||
* | ||
* If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values. | ||
* Otherwise, we default to a MeanAggregator for both training schemes. | ||
* | ||
* For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values. | ||
* Unless specified otherwise, for federated learning or local training the aggregator default to waiting | ||
* for a single contribution to trigger a model update. | ||
* (the server's model update for federated learning or our own contribution if training locally) | ||
* For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update. | ||
* | ||
* @param task The task object associated with the current training session | ||
* @param options Options passed down to the aggregator's constructor | ||
* @returns The aggregator | ||
*/ | ||
export function getAggregator (task: Task): aggregator.Aggregator { | ||
const error = new Error('not implemented') | ||
switch (task.trainingInformation.aggregator) { | ||
case AggregatorChoice.MEAN: | ||
return new aggregator.MeanAggregator() | ||
case AggregatorChoice.BANDIT: | ||
throw error | ||
case AggregatorChoice.SECURE: | ||
if (task.trainingInformation.scheme !== 'decentralized') { | ||
export function getAggregator(task: Task, options: AggregatorOptions = {}): aggregator.Aggregator { | ||
const aggregatorType = task.trainingInformation.aggregator ?? 'mean' | ||
const scheme = options.scheme ?? task.trainingInformation.scheme | ||
|
||
switch (aggregatorType) { | ||
case 'mean': | ||
if (scheme === 'decentralized') { | ||
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100% | ||
options = { | ||
model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'relative', | ||
...options | ||
} | ||
} else { | ||
// If scheme == 'federated' then we only expect the server's contribution at each round | ||
// so we set the aggregation threshold to 1 contribution | ||
// If scheme == 'local' then we only expect our own contribution | ||
options = { | ||
model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'absolute', | ||
...options | ||
} | ||
} | ||
return new aggregator.MeanAggregator(options.model, options.roundCutOff, options.threshold, options.thresholdType) | ||
case 'secure': | ||
if (scheme !== 'decentralized') { | ||
throw new Error('secure aggregation is currently supported for decentralized only') | ||
} | ||
return new aggregator.SecureAggregator() | ||
default: | ||
return new aggregator.MeanAggregator() | ||
return new aggregator.SecureAggregator(options.model, task.trainingInformation.maxShareValue) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.