Skip to content

Commit

Permalink
Merge pull request #708 from epfml/694-decentralized-fail-julien
Browse files Browse the repository at this point in the history
Fix decentralized learning fail
  • Loading branch information
JulienVig authored Jul 23, 2024
2 parents 2746a8d + 81cdae0 commit edbfce9
Show file tree
Hide file tree
Showing 35 changed files with 491 additions and 461 deletions.
15 changes: 6 additions & 9 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,19 @@ async function runUser(
url: URL,
data: data.DataSplit,
): Promise<List<RoundLogs>> {
const client = new clients.federated.FederatedClient(
url,
task,
new aggregators.MeanAggregator(),
);

// force the federated scheme
const disco = new Disco(task, { scheme: "federated", client });
const trainingScheme = task.trainingInformation.scheme
const aggregator = aggregators.getAggregator(task)
const client = clients.getClient(trainingScheme, url, task, aggregator)
const disco = new Disco(task, { scheme: trainingScheme, client });

const logs = List(await arrayFromAsync(disco.trainByRound(data)));
await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish
await disco.close();
return logs;
}

async function main (task: Task, numberOfUsers: number): Promise<void> {
console.log(`Started federated training of ${task.id}`)
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })
const [server, url] = await startServer()

Expand Down
2 changes: 1 addition & 1 deletion discojs/src/aggregator.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { NodeID } from "./client/types.js";

const AGGREGATORS: Set<[name: string, new () => Aggregator]> = Set.of<
new (model?: Model) => Aggregator
>(MeanAggregator, SecureAggregator).map((Aggregator) => [
>(MeanAggregator, SecureAggregator).map((Aggregator) => [ // MeanAggregator waits for 100% of the node's contributions by default
Aggregator.name,
Aggregator,
]);
Expand Down
7 changes: 4 additions & 3 deletions discojs/src/aggregator/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ export abstract class Base<T> {
log (step: AggregationStep, from?: client.NodeID): void {
switch (step) {
case AggregationStep.ADD:
console.log(`> Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`)
console.log(`Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`)
break
case AggregationStep.UPDATE:
if (from === undefined) {
Expand All @@ -139,8 +139,8 @@ export abstract class Base<T> {
}

/**
* Sets the aggregator's TF.js model.
* @param model The new TF.js model
* Sets the aggregator's model.
* @param model The new model
*/
setModel (model: Model): void {
this._model = model
Expand All @@ -151,6 +151,7 @@ export abstract class Base<T> {
* peer/client within the network, whom we are communicating with during this aggregation
* round.
* @param nodeId The node to be added
* @returns True is the node wasn't already in the list of nodes, False if already included
*/
registerNode (nodeId: client.NodeID): boolean {
if (!this.nodes.has(nodeId)) {
Expand Down
70 changes: 48 additions & 22 deletions discojs/src/aggregator/get.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,59 @@
import type { Task } from '../index.js'
import { aggregator } from '../index.js'
import type { Model } from "../index.js";

/**
* Enumeration of the available types of aggregator.
*/
export enum AggregatorChoice {
MEAN,
SECURE,
BANDIT
}
type AggregatorOptions = Partial<{
model: Model,
scheme: Task['trainingInformation']['scheme'], // if undefined, fallback on task.trainingInformation.scheme
roundCutOff: number, // MeanAggregator
threshold: number, // MeanAggregator
thresholdType: 'relative' | 'absolute', // MeanAggregator
}>

/**
* Provides the aggregator object adequate to the given task.
* @param task The task
* Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
* Here is the ordered list of parameters used to define the aggregator and its default behavior:
* task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
*
* If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
* Otherwise, we default to a MeanAggregator for both training schemes.
*
* For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
* Unless specified otherwise, for federated learning or local training the aggregator default to waiting
* for a single contribution to trigger a model update.
* (the server's model update for federated learning or our own contribution if training locally)
* For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update.
*
* @param task The task object associated with the current training session
* @param options Options passed down to the aggregator's constructor
* @returns The aggregator
*/
export function getAggregator (task: Task): aggregator.Aggregator {
const error = new Error('not implemented')
switch (task.trainingInformation.aggregator) {
case AggregatorChoice.MEAN:
return new aggregator.MeanAggregator()
case AggregatorChoice.BANDIT:
throw error
case AggregatorChoice.SECURE:
if (task.trainingInformation.scheme !== 'decentralized') {
export function getAggregator(task: Task, options: AggregatorOptions = {}): aggregator.Aggregator {
const aggregatorType = task.trainingInformation.aggregator ?? 'mean'
const scheme = options.scheme ?? task.trainingInformation.scheme

switch (aggregatorType) {
case 'mean':
if (scheme === 'decentralized') {
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
options = {
model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
...options
}
} else {
// If scheme == 'federated' then we only expect the server's contribution at each round
// so we set the aggregation threshold to 1 contribution
// If scheme == 'local' then we only expect our own contribution
options = {
model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
...options
}
}
return new aggregator.MeanAggregator(options.model, options.roundCutOff, options.threshold, options.thresholdType)
case 'secure':
if (scheme !== 'decentralized') {
throw new Error('secure aggregation is currently supported for decentralized only')
}
return new aggregator.SecureAggregator()
default:
return new aggregator.MeanAggregator()
return new aggregator.SecureAggregator(options.model, task.trainingInformation.maxShareValue)
}
}
2 changes: 1 addition & 1 deletion discojs/src/aggregator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ export { Base as AggregatorBase, AggregationStep } from './base.js'
export { MeanAggregator } from './mean.js'
export { SecureAggregator } from './secure.js'

export { getAggregator, AggregatorChoice } from './get.js'
export { getAggregator } from './get.js'

export type Aggregator = Base<WeightsContainer>
81 changes: 66 additions & 15 deletions discojs/src/aggregator/mean.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,81 @@ import { AggregationStep, Base as Aggregator } from "./base.js";
import type { Model, WeightsContainer, client } from "../index.js";
import { aggregation } from "../index.js";

/** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
type ThresholdType = 'relative' | 'absolute'

/**
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
*
*/
export class MeanAggregator extends Aggregator<WeightsContainer> {
readonly #threshold: number;
readonly #thresholdType: ThresholdType;

/**
* @param threshold - how many contributions for trigger an aggregation step.
* - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
* - absolute: t > 1, thus requiring t contributions
* Create a mean aggregator that averages all weight updates received when a specified threshold is met.
* By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
* only accepts contributions from the current round (drops contributions from previous rounds).
*
* @param threshold - how many contributions trigger an aggregation step.
* It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions.
* Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`.
* It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions
* Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update),
* set `threshold = 1` and `thresholdType = 'absolute'`
* @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1,
* If `threshold != 1` then the specified thresholdType is ignored and overwritten
* If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution
* if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions,
* @param roundCutoff - from how many past rounds do we still accept contributions.
* If 0 then only accept contributions from the current round,
* if 1 then the current round and the previous one, etc.
*/
// TODO no way to require a single contribution
constructor(model?: Model, roundCutoff = 0, threshold = 1) {
if (threshold <= 0) throw new Error("threshold must be striclty positive");
if (threshold > 1 && !Number.isInteger(threshold))
throw new Error("absolute thresholds must be integeral");

constructor(model?: Model, roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType) {
if (threshold <= 0) throw new Error("threshold must be strictly positive");
if (threshold > 1 && (!Number.isInteger(threshold)))
throw new Error("absolute thresholds must be integral");

super(model, roundCutoff, 1);
this.#threshold = threshold;

if (threshold < 1) {
// Throw exception if threshold and thresholdType are conflicting
if (thresholdType === 'absolute') {
throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`)
}
this.#thresholdType = 'relative'
}
else if (threshold > 1) {
// Throw exception if threshold and thresholdType are conflicting
if (thresholdType === 'relative') {
throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`)
}
this.#thresholdType = 'absolute'
}
// remaining case: threshold == 1
else {
// Print a warning regarding the default behavior when thresholdType is not specified
if (thresholdType === undefined) {
console.warn(
"[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
"To instead wait for a single contribution, set thresholdType = 'absolute'"
)
this.#thresholdType = 'relative'
} else {
this.#thresholdType = thresholdType
}
}
}

/** Checks whether the contributions buffer is full. */
override isFull(): boolean {
const actualThreshold =
this.#threshold <= 1
const thresholdValue =
this.#thresholdType == 'relative'
? this.#threshold * this.nodes.size
: this.#threshold;

return (this.contributions.get(0)?.size ?? 0) >= actualThreshold;
return (this.contributions.get(0)?.size ?? 0) >= thresholdValue;
}

override add(
Expand All @@ -42,8 +90,11 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
if (currentContributions !== 0)
throw new Error("only a single communication round");

if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
return false;
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
if (!this.nodes.has(nodeId)) console.warn("Contribution rejected because node id is not registered")
if (!this.isWithinRoundCutoff(round)) console.warn(`Contribution rejected because round ${round} is not within round cutoff`)
return false;
}

this.log(
this.contributions.hasIn([0, nodeId])
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/aggregator/secure.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ describe("secret shares test", function () {
describe("secure aggregator", () => {
it("behaves as mean aggregator", async () => {
const secureNetwork = setupNetwork(SecureAggregator)
const meanNetwork = setupNetwork(MeanAggregator)
const meanNetwork = setupNetwork(MeanAggregator) // waits for 100% of the nodes' contributions by default

const meanResults = await communicate(
Map(
Expand Down
9 changes: 6 additions & 3 deletions discojs/src/client/base.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import axios from 'axios'
import type { Set } from 'immutable'

import type { Model, Task, WeightsContainer } from '../index.js'
import { serialization } from '../index.js'
Expand Down Expand Up @@ -85,8 +84,12 @@ export abstract class Base {
_round: number,
): Promise<void> {}

get nodes (): Set<NodeID> {
return this.aggregator.nodes
// Number of contributors to a collaborative session
// If decentralized, it should be the number of peers
// If federated, it should the number of participants excluding the server
// If local it should be 1
get nbOfParticipants(): number {
return this.aggregator.nodes.size // overriden by the federated client
}

get ownId(): NodeID {
Expand Down
Loading

0 comments on commit edbfce9

Please sign in to comment.