diff --git a/cli/src/cli.ts b/cli/src/cli.ts index f4df93046..d97bde544 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -20,22 +20,19 @@ async function runUser( url: URL, data: data.DataSplit, ): Promise> { - const client = new clients.federated.FederatedClient( - url, - task, - new aggregators.MeanAggregator(), - ); - - // force the federated scheme - const disco = new Disco(task, { scheme: "federated", client }); + const trainingScheme = task.trainingInformation.scheme + const aggregator = aggregators.getAggregator(task) + const client = clients.getClient(trainingScheme, url, task, aggregator) + const disco = new Disco(task, { scheme: trainingScheme, client }); const logs = List(await arrayFromAsync(disco.trainByRound(data))); + await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish await disco.close(); return logs; } async function main (task: Task, numberOfUsers: number): Promise { - console.log(`Started federated training of ${task.id}`) + console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`) console.log({ args }) const [server, url] = await startServer() diff --git a/discojs/src/aggregator.spec.ts b/discojs/src/aggregator.spec.ts index 7a1e378e9..7101dda1e 100644 --- a/discojs/src/aggregator.spec.ts +++ b/discojs/src/aggregator.spec.ts @@ -11,7 +11,7 @@ import { NodeID } from "./client/types.js"; const AGGREGATORS: Set<[name: string, new () => Aggregator]> = Set.of< new (model?: Model) => Aggregator ->(MeanAggregator, SecureAggregator).map((Aggregator) => [ +>(MeanAggregator, SecureAggregator).map((Aggregator) => [ // MeanAggregator waits for 100% of the node's contributions by default Aggregator.name, Aggregator, ]); diff --git a/discojs/src/aggregator/base.ts b/discojs/src/aggregator/base.ts index f9b6722ee..e783723cc 100644 --- a/discojs/src/aggregator/base.ts +++ b/discojs/src/aggregator/base.ts @@ -119,7 +119,7 @@ export abstract class Base { log (step: AggregationStep, from?: client.NodeID): void { switch (step) { case AggregationStep.ADD: - console.log(`> Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`) + console.log(`Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`) break case AggregationStep.UPDATE: if (from === undefined) { @@ -139,8 +139,8 @@ export abstract class Base { } /** - * Sets the aggregator's TF.js model. - * @param model The new TF.js model + * Sets the aggregator's model. + * @param model The new model */ setModel (model: Model): void { this._model = model @@ -151,6 +151,7 @@ export abstract class Base { * peer/client within the network, whom we are communicating with during this aggregation * round. * @param nodeId The node to be added + * @returns True is the node wasn't already in the list of nodes, False if already included */ registerNode (nodeId: client.NodeID): boolean { if (!this.nodes.has(nodeId)) { diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index b9577ca48..c54cc819a 100644 --- a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -1,33 +1,59 @@ import type { Task } from '../index.js' import { aggregator } from '../index.js' +import type { Model } from "../index.js"; -/** - * Enumeration of the available types of aggregator. - */ -export enum AggregatorChoice { - MEAN, - SECURE, - BANDIT -} +type AggregatorOptions = Partial<{ + model: Model, + scheme: Task['trainingInformation']['scheme'], // if undefined, fallback on task.trainingInformation.scheme + roundCutOff: number, // MeanAggregator + threshold: number, // MeanAggregator + thresholdType: 'relative' | 'absolute', // MeanAggregator +}> /** - * Provides the aggregator object adequate to the given task. - * @param task The task + * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters. + * Here is the ordered list of parameters used to define the aggregator and its default behavior: + * task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme + * + * If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values. + * Otherwise, we default to a MeanAggregator for both training schemes. + * + * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values. + * Unless specified otherwise, for federated learning or local training the aggregator default to waiting + * for a single contribution to trigger a model update. + * (the server's model update for federated learning or our own contribution if training locally) + * For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update. + * + * @param task The task object associated with the current training session + * @param options Options passed down to the aggregator's constructor * @returns The aggregator */ -export function getAggregator (task: Task): aggregator.Aggregator { - const error = new Error('not implemented') - switch (task.trainingInformation.aggregator) { - case AggregatorChoice.MEAN: - return new aggregator.MeanAggregator() - case AggregatorChoice.BANDIT: - throw error - case AggregatorChoice.SECURE: - if (task.trainingInformation.scheme !== 'decentralized') { +export function getAggregator(task: Task, options: AggregatorOptions = {}): aggregator.Aggregator { + const aggregatorType = task.trainingInformation.aggregator ?? 'mean' + const scheme = options.scheme ?? task.trainingInformation.scheme + + switch (aggregatorType) { + case 'mean': + if (scheme === 'decentralized') { + // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100% + options = { + model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'relative', + ...options + } + } else { + // If scheme == 'federated' then we only expect the server's contribution at each round + // so we set the aggregation threshold to 1 contribution + // If scheme == 'local' then we only expect our own contribution + options = { + model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'absolute', + ...options + } + } + return new aggregator.MeanAggregator(options.model, options.roundCutOff, options.threshold, options.thresholdType) + case 'secure': + if (scheme !== 'decentralized') { throw new Error('secure aggregation is currently supported for decentralized only') } - return new aggregator.SecureAggregator() - default: - return new aggregator.MeanAggregator() + return new aggregator.SecureAggregator(options.model, task.trainingInformation.maxShareValue) } } diff --git a/discojs/src/aggregator/index.ts b/discojs/src/aggregator/index.ts index 0913c17e1..868278785 100644 --- a/discojs/src/aggregator/index.ts +++ b/discojs/src/aggregator/index.ts @@ -5,6 +5,6 @@ export { Base as AggregatorBase, AggregationStep } from './base.js' export { MeanAggregator } from './mean.js' export { SecureAggregator } from './secure.js' -export { getAggregator, AggregatorChoice } from './get.js' +export { getAggregator } from './get.js' export type Aggregator = Base diff --git a/discojs/src/aggregator/mean.ts b/discojs/src/aggregator/mean.ts index ba234d72a..31fcbdf9d 100644 --- a/discojs/src/aggregator/mean.ts +++ b/discojs/src/aggregator/mean.ts @@ -4,33 +4,81 @@ import { AggregationStep, Base as Aggregator } from "./base.js"; import type { Model, WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; -/** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */ +type ThresholdType = 'relative' | 'absolute' + +/** + * Mean aggregator whose aggregation step consists in computing the mean of the received weights. + * + */ export class MeanAggregator extends Aggregator { readonly #threshold: number; + readonly #thresholdType: ThresholdType; /** - * @param threshold - how many contributions for trigger an aggregation step. - * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions - * - absolute: t > 1, thus requiring t contributions + * Create a mean aggregator that averages all weight updates received when a specified threshold is met. + * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that + * only accepts contributions from the current round (drops contributions from previous rounds). + * + * @param threshold - how many contributions trigger an aggregation step. + * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions. + * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`. + * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions + * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update), + * set `threshold = 1` and `thresholdType = 'absolute'` + * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1, + * If `threshold != 1` then the specified thresholdType is ignored and overwritten + * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution + * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions, + * @param roundCutoff - from how many past rounds do we still accept contributions. + * If 0 then only accept contributions from the current round, + * if 1 then the current round and the previous one, etc. */ - // TODO no way to require a single contribution - constructor(model?: Model, roundCutoff = 0, threshold = 1) { - if (threshold <= 0) throw new Error("threshold must be striclty positive"); - if (threshold > 1 && !Number.isInteger(threshold)) - throw new Error("absolute thresholds must be integeral"); - + constructor(model?: Model, roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType) { + if (threshold <= 0) throw new Error("threshold must be strictly positive"); + if (threshold > 1 && (!Number.isInteger(threshold))) + throw new Error("absolute thresholds must be integral"); + + super(model, roundCutoff, 1); this.#threshold = threshold; + + if (threshold < 1) { + // Throw exception if threshold and thresholdType are conflicting + if (thresholdType === 'absolute') { + throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`) + } + this.#thresholdType = 'relative' + } + else if (threshold > 1) { + // Throw exception if threshold and thresholdType are conflicting + if (thresholdType === 'relative') { + throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`) + } + this.#thresholdType = 'absolute' + } + // remaining case: threshold == 1 + else { + // Print a warning regarding the default behavior when thresholdType is not specified + if (thresholdType === undefined) { + console.warn( + "[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " + + "To instead wait for a single contribution, set thresholdType = 'absolute'" + ) + this.#thresholdType = 'relative' + } else { + this.#thresholdType = thresholdType + } + } } /** Checks whether the contributions buffer is full. */ override isFull(): boolean { - const actualThreshold = - this.#threshold <= 1 + const thresholdValue = + this.#thresholdType == 'relative' ? this.#threshold * this.nodes.size : this.#threshold; - return (this.contributions.get(0)?.size ?? 0) >= actualThreshold; + return (this.contributions.get(0)?.size ?? 0) >= thresholdValue; } override add( @@ -42,8 +90,11 @@ export class MeanAggregator extends Aggregator { if (currentContributions !== 0) throw new Error("only a single communication round"); - if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) - return false; + if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) { + if (!this.nodes.has(nodeId)) console.warn("Contribution rejected because node id is not registered") + if (!this.isWithinRoundCutoff(round)) console.warn(`Contribution rejected because round ${round} is not within round cutoff`) + return false; + } this.log( this.contributions.hasIn([0, nodeId]) diff --git a/discojs/src/aggregator/secure.spec.ts b/discojs/src/aggregator/secure.spec.ts index b4224100f..8888cad39 100644 --- a/discojs/src/aggregator/secure.spec.ts +++ b/discojs/src/aggregator/secure.spec.ts @@ -58,7 +58,7 @@ describe("secret shares test", function () { describe("secure aggregator", () => { it("behaves as mean aggregator", async () => { const secureNetwork = setupNetwork(SecureAggregator) - const meanNetwork = setupNetwork(MeanAggregator) + const meanNetwork = setupNetwork(MeanAggregator) // waits for 100% of the nodes' contributions by default const meanResults = await communicate( Map( diff --git a/discojs/src/client/base.ts b/discojs/src/client/base.ts index 1bfdb4f19..765a90efe 100644 --- a/discojs/src/client/base.ts +++ b/discojs/src/client/base.ts @@ -1,5 +1,4 @@ import axios from 'axios' -import type { Set } from 'immutable' import type { Model, Task, WeightsContainer } from '../index.js' import { serialization } from '../index.js' @@ -85,8 +84,12 @@ export abstract class Base { _round: number, ): Promise {} - get nodes (): Set { - return this.aggregator.nodes + // Number of contributors to a collaborative session + // If decentralized, it should be the number of peers + // If federated, it should the number of participants excluding the server + // If local it should be 1 + get nbOfParticipants(): number { + return this.aggregator.nodes.size // overriden by the federated client } get ownId(): NodeID { diff --git a/discojs/src/client/decentralized/base.ts b/discojs/src/client/decentralized/base.ts index a7118d7be..101a472be 100644 --- a/discojs/src/client/decentralized/base.ts +++ b/discojs/src/client/decentralized/base.ts @@ -20,80 +20,20 @@ export class Base extends Client { */ private pool?: PeerPool private connections?: Map - - /** - * Send message to server that this client is ready for the next training round. - */ - private async waitForPeers (round: number): Promise> { - console.info(`[${this.ownId}] is ready for round`, round) - - // Broadcast our readiness - const readyMessage: messages.PeerIsReady = { type: type.PeerIsReady } - - if (this.server === undefined) { - throw new Error('server undefined, could not connect peers') - } - this.server.send(readyMessage) - - // Wait for peers to be connected before sending any update information - try { - const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound, undefined, "Timeout waiting for the round's peer list") - if (this.nodes.size > 0) { - throw new Error('got new peer list from server but was already received for this round') - } - - const peers = Set(receivedMessage.peers) - console.info(`[${this.ownId}] received peers for round:`, peers.toJS()) - if (this.ownId !== undefined && peers.has(this.ownId)) { - throw new Error('received peer list contains our own id') - } - - this.aggregator.setNodes(peers.add(this.ownId)) - - if (this.pool === undefined) { - throw new Error('waiting for peers but peer pool is undefined') - } - - const connections = await this.pool.getPeers( - peers, - this.server, - // Init receipt of peers weights - (conn) => { this.receivePayloads(conn, round) } - ) - - console.info(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS()) - return connections - } catch (e) { - console.error(e) - this.aggregator.setNodes(Set(this.ownId)) - return Map() - } + + // Used to handle timeouts and promise resolving after calling disconnect + private get isDisconnected() : boolean { + return this._server === undefined } - - protected sendMessagetoPeer (peer: PeerConnection, msg: messages.PeerMessage): void { - console.info(`[${this.ownId}] send message to peer`, msg.peer, msg) - peer.send(msg) - } - + /** - * Creation of the WebSocket for the server, connection of client to that WebSocket, - * deals with message reception from the decentralized client's perspective (messages received by client). + * Public method called by disco.ts when starting training. This method sends + * a message to the server asking to join the task and be assigned a client ID. + * + * The peer also establishes a WebSocket connection with the server to then + * create peer-to-peer WebRTC connections with peers. The server is used to exchange + * peers network information. */ - private async connectServer (url: URL): Promise { - const server: EventConnection = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer) - - server.on(type.SignalForPeer, (event) => { - console.info(`[${this.ownId}] received signal from`, event.peer) - - if (this.pool === undefined) { - throw new Error('received signal but peer pool is undefined') - } - this.pool.signal(event.peer, event.signal) - }) - - return server - } - async connect (): Promise { const serverURL = new URL('', this.url.href) switch (this.url.protocol) { @@ -114,9 +54,9 @@ export class Base extends Client { type: type.ClientConnected } this.server.send(msg) - + const peerIdMsg = await waitMessage(this.server, type.AssignNodeID) - console.info(`[${peerIdMsg.id}] assigned id generated by server`) + console.log(`[${peerIdMsg.id}] assigned id generated by server`) if (this._ownId !== undefined) { throw new Error('received id from server but was already received') @@ -125,6 +65,25 @@ export class Base extends Client { this.pool = new PeerPool(peerIdMsg.id) } + /** + * Create a WebSocket connection with the server + * The client then waits for the server to forward it other client's network information. + * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection. + */ + private async connectServer (url: URL): Promise { + const server: EventConnection = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer) + + server.on(type.SignalForPeer, (event) => { + if (this.pool === undefined) { + throw new Error('received signal but peer pool is undefined') + } + // Create a WebRTC connection with the peer + this.pool.signal(event.peer, event.signal) + }) + + return server + } + async disconnect (): Promise { // Disconnect from peers await this.pool?.shutdown() @@ -134,26 +93,108 @@ export class Base extends Client { const peers = this.connections.keySeq().toSet() this.aggregator.setNodes(this.aggregator.nodes.subtract(peers)) } - // Disconnect from server await this.server?.disconnect() this._server = undefined this._ownId = undefined - + return Promise.resolve() } + /** + * At the beginning of a round, each peer tells the server it is ready to proceed + * The server answers with the list of all peers connected for the round + * Given the list, the peers then create peer-to-peer connections with each other. + * When connected, one peer creates a promise for every other peer's weight update + * and waits for it to resolve. + * + */ override async onRoundBeginCommunication ( _: WeightsContainer, round: number, ): Promise { + if (this.server === undefined) { + throw new Error("peer's server is undefined, make sure to call `client.connect()` first") + } if (this.pool === undefined) { + throw new Error('peer pool is undefined, make sure to call `client.connect()` first') + } // Reset peers list at each round of training to make sure client works with an updated peers // list, maintained by the server. Adds any received weights to the aggregator. - this.connections = await this.waitForPeers(round) + // this.connections = await this.waitForPeers(round) + + // Tell the server we are ready for the next round + const readyMessage: messages.PeerIsReady = { type: type.PeerIsReady } + + this.server.send(readyMessage) + + // Wait for the server to answer with the list of peers for the round + try { + const receivedMessage = await waitMessageWithTimeout( + this.server, type.PeersForRound, undefined, + "Timeout waiting for the round's peer list" + ) + + const peers = Set(receivedMessage.peers) + + if (this.ownId !== undefined && peers.has(this.ownId)) { + throw new Error('received peer list contains our own id') + } + // Store the list of peers for the current round including ourselves + this.aggregator.setNodes(peers.add(this.ownId)) + + // Initiate peer to peer connections with each peer + // When connected, create a promise waiting for each peer's round contribution + const connections = await this.pool.getPeers( + peers, + this.server, + // Init receipt of peers weights + // this awaits the peer's weight update and adds it to + // our aggregator upon reception + (conn) => { this.receivePayloads(conn, round) } + ) + + console.log(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS()) + this.connections = connections + } catch (e) { + console.error(e) + this.aggregator.setNodes(Set(this.ownId)) + this.connections = Map() + } + // Store the promise for the current round's aggregation result. + // We will await for it to resolve at the end of the round when exchanging weight updates. this.aggregationResult = this.aggregator.receiveResult() } + /** + * At each communication rounds, awaits peers contributions and add them to the client's aggregator. + * This method is used as callback by getPeers when connecting to the rounds' peers + * @param connections + * @param round + */ + private receivePayloads (connections: Map, round: number): void { + connections.forEach(async (connection, peerId) => { + let currentCommunicationRounds = 0 + console.log(`waiting for peer ${peerId}`) + do { + try { + const message = await waitMessageWithTimeout(connection, type.Payload, + 60_000, "Timeout waiting for a contribution from peer " + peerId) + const decoded = serialization.weights.decode(message.payload) + + if (!this.aggregator.add(peerId, decoded, round, message.round)) { + console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`) + } + } catch (e) { + if (this.isDisconnected) { + return + } + console.error(e instanceof Error ? e.message : e) + } + } while (++currentCommunicationRounds < this.aggregator.communicationRounds) + }) + } + override async onRoundEndCommunication ( weights: WeightsContainer, round: number, @@ -174,15 +215,17 @@ export class Base extends Client { if (id === this.ownId) { this.aggregator.add(this.ownId, payload, round, r) } else { - const connection = this.connections?.get(id) - if (connection !== undefined) { + const peer = this.connections?.get(id) + if (peer !== undefined) { const encoded = await serialization.weights.encode(payload) - this.sendMessagetoPeer(connection, { + const msg: messages.PeerMessage = { type: type.Payload, peer: id, round: r, payload: encoded - }) + } + peer.send(msg) + console.log(`[${this.ownId}] send weight update to peer`, msg.peer, msg) } } })) @@ -193,13 +236,20 @@ export class Base extends Client { if (this.aggregationResult === undefined) { throw new TypeError('aggregation result promise is undefined') } - // Wait for aggregation before proceeding to the next communication round. // The current result will be used as payload for the eventual next communication round. - result = await Promise.race([ - this.aggregationResult, - timeout(undefined, "Timeout waiting on the aggregation result promise to resolve") - ]) + try { + result = await Promise.race([ + this.aggregationResult, + timeout(undefined, "Timeout waiting on the aggregation result promise to resolve") + ]) + } catch (e) { + if (this.isDisconnected) { + return + } + console.error(e) + break + } // There is at least one communication round remaining if (r < this.aggregator.communicationRounds - 1) { @@ -211,24 +261,4 @@ export class Base extends Client { // Reset the peers list for the next round this.aggregator.resetNodes() } - - private receivePayloads (connections: Map, round: number): void { - console.info(`[${this.ownId}] Accepting new contributions for round ${round}`) - connections.forEach(async (connection, peerId) => { - let receivedPayloads = 0 - do { - try { - const message = await waitMessageWithTimeout(connection, type.Payload, - undefined, "Timeout waiting for a contribution from peer " + peerId) - const decoded = serialization.weights.decode(message.payload) - - if (!this.aggregator.add(peerId, decoded, round, message.round)) { - console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`) - } - } catch (e) { - console.error(e instanceof Error ? e.message : e) - } - } while (++receivedPayloads < this.aggregator.communicationRounds) - }) - } } diff --git a/discojs/src/client/decentralized/peer.ts b/discojs/src/client/decentralized/peer.ts index 573c8b634..43b7d1857 100644 --- a/discojs/src/client/decentralized/peer.ts +++ b/discojs/src/client/decentralized/peer.ts @@ -96,15 +96,7 @@ export class Peer { if (this.bufferSize === undefined) { throw new Error('chunk without known buffer size') } - - // in the perfect world of bug-free implementations - // we would return this.bufferSize - // sadly, we are not there yet - // - // based on MDN, taking 16K seems to be a pretty safe - // and widely supported buffer size - - return 16 * (1 << 10) + return this.bufferSize } private chunk (b: Buffer): Seq.Indexed { @@ -129,7 +121,8 @@ export class Peer { const totalChunkCount = 1 + tail.count() if (totalChunkCount > 0xFF) { - throw new Error('too big message to even chunk it') + throw new Error(`The payload is too big: ${totalChunkCount * this.maxChunkSize} bytes > 255,` + + ' consider reducing the model size or increasing the chunk size') } const firstChunk = Buffer.alloc( @@ -164,7 +157,7 @@ export class Peer { } signal (signal: SignalData): void { - // extract max buffer size + // extract max buffer size from the signal if (signal.type === 'offer' || signal.type === 'answer') { if (signal.sdp === undefined) { throw new Error('signal answer|offer without session description') @@ -201,9 +194,8 @@ export class Peer { if (!Buffer.isBuffer(data) || data.length < HEADER_SIZE) { throw new Error('received invalid message type') } - - const messageID: MessageID = data.readUint16BE() - const chunkID: ChunkID = data.readUint8(2) + const messageID: MessageID = data.readUInt16BE() //readUint16BE (case sensitive) fails at runtime + const chunkID: ChunkID = data.readUInt8(2) // same for readUint8 const received = this.receiving.get(messageID, { total: undefined, @@ -228,7 +220,7 @@ export class Peer { throw new Error('first header received twice') } - const readTotal = data.readUint8(3) + const readTotal = data.readUInt8(3) total = readTotal chunk = Buffer.alloc(data.length - FIRST_HEADER_SIZE) data.copy(chunk, 0, FIRST_HEADER_SIZE) diff --git a/discojs/src/client/decentralized/peer_pool.ts b/discojs/src/client/decentralized/peer_pool.ts index 6bb58c142..ca2924d50 100644 --- a/discojs/src/client/decentralized/peer_pool.ts +++ b/discojs/src/client/decentralized/peer_pool.ts @@ -14,9 +14,13 @@ export class PeerPool { ) {} async shutdown (): Promise { - console.info(`[${this.id}] shutdown their peers`) + console.info(`[${this.id}] is shutting down all its connections`) - await Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect())) + // Add a timeout o.w. the promise hangs forever if the other peer is already disconnected + await Promise.race([ + Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect())), + new Promise((res, _) => setTimeout(res, 1000)) // Wait for other peers to finish + ]) this.peers = Map() } diff --git a/discojs/src/client/event_connection.ts b/discojs/src/client/event_connection.ts index 4b45f2726..ad607df18 100644 --- a/discojs/src/client/event_connection.ts +++ b/discojs/src/client/event_connection.ts @@ -61,7 +61,7 @@ export class PeerConnection extends EventEmitter<{ [K in type]: NarrowMessage this.emit(msg.type, msg) }) - this.peer.on('close', () => { console.warn('peer', this.peer.id, 'closed connection') }) + this.peer.on('close', () => { console.warn('From', this._ownId, ': peer', this.peer.id, 'closed connection') }) await new Promise((resolve) => { this.peer.on('connect', resolve) @@ -79,7 +79,7 @@ export class PeerConnection extends EventEmitter<{ [K in type]: NarrowMessage this.peer.send(msgpack.encode(msg)) } - async disconnect (): Promise { + async disconnect(): Promise { await this.peer.destroy() } } diff --git a/discojs/src/client/federated/base.ts b/discojs/src/client/federated/base.ts index f4581b316..1e981aa44 100644 --- a/discojs/src/client/federated/base.ts +++ b/discojs/src/client/federated/base.ts @@ -1,12 +1,7 @@ -import { Map } from "immutable"; - import { serialization, - type MetadataKey, - type MetadataValue, type WeightsContainer, } from "../../index.js"; -import { type NodeID } from "../types.js"; import { Base as Client } from "../base.js"; import { type, type ClientConnected } from "../messages.js"; import { @@ -27,10 +22,15 @@ export class Base extends Client { * by this client class, the server is the only node which we are communicating with. */ public static readonly SERVER_NODE_ID = "federated-server-node-id"; - /** - * Map of metadata values for each node id. - */ - private metadataMap?: Map; + + // Total number of other federated contributors, including this client, excluding the server + // E.g., if 3 users are training a federated model, nbOfParticipants is 3 + #nbOfParticipants: number = 1; + + // the number of participants excluding the server + override get nbOfParticipants(): number { + return this.#nbOfParticipants + } /** * Opens a new WebSocket connection with the server and listens to new messages over the channel @@ -94,86 +94,6 @@ export class Base extends Client { return Promise.resolve(); } - /** - * Send a message containing our local weight updates to the federated server. - * And waits for the server to reply with the most recent aggregated weights - * @param payload The weight updates to send - */ - private async sendPayloadAndReceiveResult( - payload: WeightsContainer, - ): Promise { - const msg: messages.SendPayload = { - type: type.SendPayload, - payload: await serialization.weights.encode(payload), - round: this.aggregator.round, - }; - this.server.send(msg); - // It is important than the client immediately awaits the server result or it may miss it - return await this.receiveResult(); - } - - /** - * Waits for the server's result for its current (most recent) round and add it to our aggregator. - * Updates the aggregator's round if it's behind the server's. - */ - private async receiveResult(): Promise { - try { - const { payload, round } = await waitMessageWithTimeout( - this.server, - type.ReceiveServerPayload, - ); - const serverRound = round; - - // Store the server result only if it is not stale - if (this.aggregator.round <= round) { - const serverResult = serialization.weights.decode(payload); - // Update the local round to match the server's - if (this.aggregator.round < serverRound) { - this.aggregator.setRound(serverRound); - } - return serverResult; - } - } catch (e) { - console.error(e); - } - } - - /** - * Fetch the metadata values maintained by the federated server, for a given metadata key. - * The values are indexed by node id. - * @param key The metadata key - * @returns The map of node id to metadata value - */ - async receiveMetadataMap( - key: MetadataKey, - ): Promise | undefined> { - this.metadataMap = undefined; - - const msg: messages.ReceiveServerMetadata = { - type: type.ReceiveServerMetadata, - taskId: this.task.id, - nodeId: this.ownId, - round: this.aggregator.round, - key, - }; - - this.server.send(msg); - - const received = await waitMessageWithTimeout( - this.server, - type.ReceiveServerMetadata, - ); - if (received.metadataMap !== undefined) { - this.metadataMap = Map( - received.metadataMap.filter(([_, v]) => v !== undefined) as Array< - [NodeID, MetadataValue] - >, - ); - } - - return this.metadataMap; - } - override onRoundBeginCommunication(): Promise { // Prepare the result promise for the incoming round this.aggregationResult = this.aggregator.receiveResult(); @@ -212,4 +132,44 @@ export class Base extends Client { this.aggregator.nextRound(); } } + + /** + * Send a message containing our local weight updates to the federated server. + * And waits for the server to reply with the most recent aggregated weights + * @param payload The weight updates to send + */ + private async sendPayloadAndReceiveResult( + payload: WeightsContainer, + ): Promise { + const msg: messages.SendPayload = { + type: type.SendPayload, + payload: await serialization.weights.encode(payload), + round: this.aggregator.round, + }; + this.server.send(msg); + + // Waits for the server's result for its current (most recent) round and add it to our aggregator. + // Updates the aggregator's round if it's behind the server's. + try { + // It is important than the client immediately awaits the server result or it may miss it + const { payload, round, nbOfParticipants } = await waitMessageWithTimeout( + this.server, + type.ReceiveServerPayload, + ); + const serverRound = round; + this.#nbOfParticipants = nbOfParticipants // Save the current participants + + // Store the server result only if it is not stale + if (this.aggregator.round <= round) { + const serverResult = serialization.weights.decode(payload); + // Update the local round to match the server's + if (this.aggregator.round < serverRound) { + this.aggregator.setRound(serverRound); + } + return serverResult; + } + } catch (e) { + console.error(e); + } + } } diff --git a/discojs/src/client/federated/messages.ts b/discojs/src/client/federated/messages.ts index a18dc1565..c55c0dada 100644 --- a/discojs/src/client/federated/messages.ts +++ b/discojs/src/client/federated/messages.ts @@ -1,4 +1,3 @@ -import { type client, type MetadataKey, type MetadataValue } from '../../index.js' import { type weights } from '../../serialization/index.js' import { type, hasMessageType, type AssignNodeID, type ClientConnected } from '../messages.js' @@ -7,7 +6,6 @@ export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | - ReceiveServerMetadata | AssignNodeID export interface SendPayload { @@ -18,15 +16,8 @@ export interface SendPayload { export interface ReceiveServerPayload { type: type.ReceiveServerPayload payload: weights.Encoded - round: number -} -export interface ReceiveServerMetadata { - type: type.ReceiveServerMetadata - nodeId: client.NodeID - taskId: string - round: number - key: MetadataKey - metadataMap?: Array<[client.NodeID, MetadataValue | undefined]> + round: number, + nbOfParticipants: number // number of peers contributing to a federated training } export function isMessageFederated (raw: unknown): raw is MessageFederated { @@ -38,7 +29,6 @@ export function isMessageFederated (raw: unknown): raw is MessageFederated { case type.ClientConnected: case type.SendPayload: case type.ReceiveServerPayload: - case type.ReceiveServerMetadata: case type.AssignNodeID: return true } diff --git a/discojs/src/client/index.ts b/discojs/src/client/index.ts index c657140b7..6a88e6c0c 100644 --- a/discojs/src/client/index.ts +++ b/discojs/src/client/index.ts @@ -6,6 +6,6 @@ export * as aggregator from '../aggregator/index.js' export * as decentralized from './decentralized/index.js' export * as federated from './federated/index.js' export * as messages from './messages.js' -export * as utils from './utils.js' +export { getClient, timeout } from './utils.js' export { Local } from './local.js' diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index b5712c349..26f85c264 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -3,19 +3,23 @@ import type * as federated from './federated/messages.js' import { type NodeID } from './types.js' export enum type { - ClientConnected, + // Sent from client to server as first point of contact to join a task. + // The server answers with an node id in a AssignNodeID message + ClientConnected, + // When a user joins a task with a ClientConnected message, the server + // answers with an AssignNodeID message with its peer id. AssignNodeID, - // Decentralized + /* Decentralized */ + // Message forwarded by the server from a client to another client + // to establish a peer-to-peer (WebRTC) connection SignalForPeer, PeerIsReady, PeersForRound, - Payload, // Federated SendPayload, - ReceiveServerMetadata, ReceiveServerPayload, } diff --git a/discojs/src/client/utils.ts b/discojs/src/client/utils.ts index 9702fe7ad..21b6afb3c 100644 --- a/discojs/src/client/utils.ts +++ b/discojs/src/client/utils.ts @@ -1,8 +1,28 @@ +import type { Task } from '../index.js' +import { client as clients, type aggregator } from '../index.js' + // Time to wait for the others in milliseconds. -export const MAX_WAIT_PER_ROUND = 15_000 +const MAX_WAIT_PER_ROUND = 15_000 export async function timeout (ms = MAX_WAIT_PER_ROUND, errorMsg: string = 'timeout'): Promise { return await new Promise((_, reject) => { setTimeout(() => { reject(new Error(errorMsg)) }, ms) }) } + +export function getClient(trainingScheme: Required, + serverURL: URL, task: Task, aggregator: aggregator.Aggregator): clients.Client { + + switch (trainingScheme) { + case 'decentralized': + return new clients.decentralized.DecentralizedClient(serverURL, task, aggregator) + case 'federated': + return new clients.federated.FederatedClient(serverURL, task, aggregator) + case 'local': + return new clients.Local(serverURL, task, aggregator) + default: { + const _: never = trainingScheme + throw new Error('should never happen') + } + } +} diff --git a/discojs/src/default_tasks/mnist.ts b/discojs/src/default_tasks/mnist.ts index 8fbd46bf1..5c0ccbf51 100644 --- a/discojs/src/default_tasks/mnist.ts +++ b/discojs/src/default_tasks/mnist.ts @@ -23,13 +23,14 @@ export const mnist: TaskProvider = { trainingInformation: { modelID: 'mnist-model', epochs: 20, - roundDuration: 10, + roundDuration: 2, validationSplit: 0.2, - batchSize: 30, + batchSize: 64, dataType: 'image', IMAGE_H: 28, IMAGE_W: 28, - preprocessingFunctions: [data.ImagePreprocessing.Normalize], + // Images should already be at the right size but resizing just in case + preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize], LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], scheme: 'decentralized', noiseScale: undefined, @@ -42,31 +43,30 @@ export const mnist: TaskProvider = { } }, - getModel (): Promise { + getModel(): Promise { + // Architecture from the PyTorch MNIST example (I made it slightly smaller, 650kB instead of 5MB) + // https://github.com/pytorch/examples/blob/main/mnist/main.py const model = tf.sequential() model.add( tf.layers.conv2d({ inputShape: [28, 28, 3], - kernelSize: 3, - filters: 16, - activation: 'relu' + kernelSize: 5, + filters: 8, + activation: 'relu', }) ) + model.add(tf.layers.conv2d({ kernelSize: 5, filters: 16, activation: 'relu' })) model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) - model.add( - tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }) - ) - model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) - model.add( - tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }) - ) - model.add(tf.layers.flatten({})) - model.add(tf.layers.dense({ units: 64, activation: 'relu' })) + model.add(tf.layers.dropout({ rate: 0.25 })) + + model.add(tf.layers.flatten()) + model.add(tf.layers.dense({ units: 32, activation: 'relu' })) + model.add(tf.layers.dropout({rate:0.25})) model.add(tf.layers.dense({ units: 10, activation: 'softmax' })) model.compile({ - optimizer: 'rmsprop', + optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy'] }) diff --git a/discojs/src/task/task_handler.ts b/discojs/src/task/task_handler.ts index 1fe8e2981..79da2be56 100644 --- a/discojs/src/task/task_handler.ts +++ b/discojs/src/task/task_handler.ts @@ -28,8 +28,15 @@ export async function fetchTasks (url: URL): Promise> { const response = await axios.get(new URL(TASK_ENDPOINT, url).href) const tasks: unknown = response.data - if (!(Array.isArray(tasks) && tasks.every(isTask))) { - throw new Error('invalid tasks response') + if (!Array.isArray(tasks)) { + throw new Error('Expected to receive an array of Tasks when fetching tasks') + } else if (!tasks.every(isTask)) { + for (const task of tasks) { + if (!isTask(task)) { + console.error("task has invalid format:", task) + } + } + throw new Error('invalid tasks response, the task object received is not well formatted') } return Map(tasks.map((t) => [t.id, t])) diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index 80adc2d04..8618665ac 100644 --- a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -1,4 +1,3 @@ -import type { AggregatorChoice } from '../aggregator/get.js' import type { Preprocessing } from '../dataset/data/preprocessing/index.js' import { PreTrainedTokenizer } from '@xenova/transformers'; @@ -7,8 +6,8 @@ export interface TrainingInformation { modelID: string // epochs: number of epochs to run training for epochs: number - // roundDuration: number of batches between each weight sharing round, e.g. if 3 then after every - // 3 batches we share weights (in the distributed setting). + // roundDuration: number of epochs between each weight sharing round. + // e.g.if 3 then weights are shared every 3 epochs (in the distributed setting). roundDuration: number // validationSplit: fraction of data to keep for validation, note this only works for image data validationSplit: number @@ -50,7 +49,7 @@ export interface TrainingInformation { minimumReadyPeers?: number // aggregator: aggregator to be used by the server for federated learning, or by the peers for decentralized learning // default is 'average', other options include for instance 'bandit' - aggregator?: AggregatorChoice + aggregator?: 'mean' | 'secure' // TODO: never used // tokenizer (string | PreTrainedTokenizer). This field should be initialized with the name of a Transformers.js pre-trained tokenizer, e.g., 'Xenova/gpt2'. // When the tokenizer is first called, the actual object will be initialized and loaded into this field for the subsequent tokenizations. tokenizer?: string | PreTrainedTokenizer @@ -109,7 +108,7 @@ export function isTrainingInformation (raw: unknown): raw is TrainingInformation typeof validationSplit !== 'number' || (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) || (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') || - (aggregator !== undefined && typeof aggregator !== 'number') || + (aggregator !== undefined && typeof aggregator !== 'string') || (clippingRadius !== undefined && typeof clippingRadius !== 'number') || (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') || (maxShareValue !== undefined && typeof maxShareValue !== 'number') || @@ -125,6 +124,14 @@ export function isTrainingInformation (raw: unknown): raw is TrainingInformation return false } + if (aggregator !== undefined) { + switch (aggregator) { + case 'mean': break + case 'secure': break + default: return false + } + } + switch (dataType) { case 'image': break case 'tabular': break diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7e2f42e99..7cc2377e5 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -2,8 +2,7 @@ import { List } from 'immutable' import { async_iterator, BatchLogs, data, EpochLogs, Logger, Memory, Task, TrainingInformation } from '../index.js' import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js' -import type { Aggregator } from '../aggregator/index.js' -import { MeanAggregator } from '../aggregator/mean.js' +import {getAggregator, type Aggregator } from '../aggregator/index.js' import { enumerate, split } from '../utils/async_iterator.js' import type { RoundLogs, Trainer } from './trainer/trainer.js' @@ -28,41 +27,28 @@ export class Disco { public readonly logger: Logger public readonly memory: Memory private readonly client: clients.Client - private readonly trainer: Promise + private readonly trainerPromise: Promise constructor ( task: Task, options: DiscoOptions ) { + // Fill undefined options with default values if (options.scheme === undefined) { options.scheme = task.trainingInformation.scheme } - if (options.aggregator === undefined) { - options.aggregator = new MeanAggregator() - } if (options.client === undefined) { if (options.url === undefined) { throw new Error('could not determine client from given parameters') } + if (options.aggregator === undefined) { + options.aggregator = getAggregator(task, { scheme: options.scheme }) + } if (typeof options.url === 'string') { options.url = new URL(options.url) } - switch (options.scheme) { - case 'federated': - options.client = new clients.federated.FederatedClient(options.url, task, options.aggregator) - break - case 'decentralized': - options.client = new clients.decentralized.DecentralizedClient(options.url, task, options.aggregator) - break - case 'local': - options.client = new clients.Local(options.url, task, options.aggregator) - break - default: { - const _: never = options.scheme - throw new Error('should never happen') - } - } + options.client = clients.getClient(options.scheme, options.url, task, options.aggregator) } if (options.logger === undefined) { options.logger = new ConsoleLogger() @@ -80,7 +66,7 @@ export class Disco { this.logger = options.logger const trainerBuilder = new TrainerBuilder(this.memory, this.task) - this.trainer = trainerBuilder.build(this.client, options.scheme !== 'local') + this.trainerPromise = trainerBuilder.build(this.client, options.scheme !== 'local') } /** Train on dataset, yielding logs of every round. */ @@ -140,7 +126,7 @@ export class Disco { const validationData = dataTuple.validation?.preprocess().batch() ?? trainData; await this.client.connect(); - const trainer = await this.trainer; + const trainer = await this.trainerPromise; for await (const [round, epochs] of enumerate( trainer.fitModel(trainData.dataset, validationData.dataset), @@ -172,7 +158,7 @@ export class Disco { return { epochs: epochsLogs, - participants: this.client.nodes.size + 1, // add ourself + participants: this.client.nbOfParticipants, // already includes ourselves }; }.bind(this)(); } @@ -184,7 +170,7 @@ export class Disco { * Stops the ongoing training instance without disconnecting the client. */ async pause (): Promise { - const trainer = await this.trainer + const trainer = await this.trainerPromise await trainer.stopTraining() } diff --git a/discojs/src/training/trainer/trainer.ts b/discojs/src/training/trainer/trainer.ts index 79ae937a1..8ae8178e4 100644 --- a/discojs/src/training/trainer/trainer.ts +++ b/discojs/src/training/trainer/trainer.ts @@ -34,7 +34,7 @@ export abstract class Trainer { this.#epochs = task.trainingInformation.epochs; if (!Number.isInteger(this.#epochs / this.#roundDuration)) - throw new Error(`round duration doesn't divide epochs`); + throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`); } protected abstract onRoundBegin(round: number): Promise; diff --git a/discojs/src/types.ts b/discojs/src/types.ts index 7cf8cb020..783f8e344 100644 --- a/discojs/src/types.ts +++ b/discojs/src/types.ts @@ -6,9 +6,6 @@ import type { NodeID } from './client/index.js' // Filesystem reference export type Path = string -export type MetadataKey = string -export type MetadataValue = string - export type Features = number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] export type Contributions = Map diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index f4cba4d09..1e10a3cd0 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -185,7 +185,7 @@ git push -u origin 202-train-bug-nacho - TypeScript files should be written in snake_case, lowercase words separated by underscores, e.g. `event_connection.ts` - Vue.js files should be written in PascalCase (capitalized words including the first), e.g. `DatasetInput.vue` -- Classes and types should also be written in PascalCase. For example class `AsyncInformant` and type `MetadataValue` +- Classes, interfaces and types should also be written in PascalCase. For example class `MeanAggregator` and interface `EventConnection` - Functions and variable names should be written in camelCase, starting with a lowercase letter: function `isWithinRoundCutoff` and variable `roundCutoff` #### Docstring @@ -201,8 +201,7 @@ Test the newly implemented features locally by following instructions in the [Co Once you have added a minimum number of content to your branch, you can create a [draft PR](https://github.blog/2019-02-14-introducing-draft-pull-requests/). Create a pull request to merge your branch (e.g., `202-train-bug-nacho`) into the `develop` branch. `develop` should always be functional and up to date with new working features. It is the equivalent of the `main`or `master` branch in DISCO. It is important to give a good description to your PR as this makes it easier for other people to go through it. -> [!TIP] -> [This PR](https://github.com/epfml/disco/pull/176) is a good example. +> [!TIP] > [This PR](https://github.com/epfml/disco/pull/176) is a good example. ### 5. Before requesting a review diff --git a/docs/examples/wikitext.ts b/docs/examples/wikitext.ts index 0f3ae4145..9471ea2f5 100644 --- a/docs/examples/wikitext.ts +++ b/docs/examples/wikitext.ts @@ -26,7 +26,7 @@ async function main(): Promise { const dataset = await loadWikitextData(task) // Initialize a Disco instance and start training a language model - const aggregator = new aggregators.MeanAggregator() + const aggregator = aggregators.getAggregator(task) const client = new clients.federated.FederatedClient(url, task, aggregator) const disco = new Disco(task, { scheme: 'federated', client, aggregator }) await disco.trainFully(dataset); diff --git a/server/src/router/decentralized/server.ts b/server/src/router/decentralized/server.ts index dba00779e..e3a459fe7 100644 --- a/server/src/router/decentralized/server.ts +++ b/server/src/router/decentralized/server.ts @@ -62,9 +62,11 @@ export class Decentralized extends Server { } switch (msg.type) { - // A new peer joins the network + // A new peer joins the network for a task case MessageTypes.ClientConnected: { this.connections = this.connections.set(peerId, ws) + + // Answer with client id in an AssignNodeID message const msg: AssignNodeID = { type: MessageTypes.AssignNodeID, id: peerId @@ -79,17 +81,8 @@ export class Decentralized extends Server { ws.send(msgpack.encode(msg), { binary: true }) break } - - // Forwards a peer's message to another destination peer - case MessageTypes.SignalForPeer: { - const forward: messages.SignalForPeer = { - type: MessageTypes.SignalForPeer, - peer: peerId, - signal: msg.signal - } - this.connections.get(msg.peer)?.send(msgpack.encode(forward)) - break - } + // Send by peers at the beginning of each training round to get the list + // of active peers for this round. case MessageTypes.PeerIsReady: { const peers = this.readyNodes.get(task.id)?.add(peerId) if (peers === undefined) { @@ -119,6 +112,18 @@ export class Decentralized extends Server { } break } + // Forwards a peer's message to another destination peer + // Used to exchange peer's information and establish a direct + // WebRTC connection between them + case MessageTypes.SignalForPeer: { + const forward: messages.SignalForPeer = { + type: MessageTypes.SignalForPeer, + peer: peerId, + signal: msg.signal + } + this.connections.get(msg.peer)?.send(msgpack.encode(forward)) + break + } default: { const _: never = msg throw new Error('should never happen') diff --git a/server/src/router/federated/server.ts b/server/src/router/federated/server.ts index b7ea8adfb..cabbd1d78 100644 --- a/server/src/router/federated/server.ts +++ b/server/src/router/federated/server.ts @@ -8,8 +8,6 @@ import type { Task, TaskID, WeightsContainer, - MetadataKey, - MetadataValue } from '@epfml/discojs' import { client, @@ -55,11 +53,7 @@ export class Federated extends Server { * Training informants for each hosted task. */ private informants = Map>() - /** - * Contains metadata used for training by clients for a given task and round. - * Stored by task id, round number, node id and metadata key. - */ - private metadataMap = Map>>>() + // TODO use real log system /** * Logs of successful requests made to the server. @@ -108,8 +102,9 @@ export class Federated extends Server { void this.storeAggregationResult(task, aggregator) } - protected initTask (task: TaskID, model: Model): void { - const aggregator = new aggregators.MeanAggregator(model) + protected initTask(task: TaskID, model: Model): void { + // The server waits for 100% of the nodes to send their contributions before aggregating the updates + const aggregator = new aggregators.MeanAggregator(model, undefined, 1, 'relative') this.aggregators = this.aggregators.set(task, aggregator) this.informants = this.informants.set(task, new AsyncInformant(aggregator)) @@ -141,7 +136,7 @@ export class Federated extends Server { // Wait for aggregation result to resolve with timeout, giving the network a time window // to contribute to the model - void Promise.race([promisedResult, client.utils.timeout()]) + void Promise.race([promisedResult, client.timeout()]) //TODO: it doesn't make sense that the server is using the client utils' timeout .then((result) => // Reply with round - 1 because the round number should match the round at which the client sent its weights // After the server aggregated the weights it also incremented the round so the server replies with round - 1 @@ -152,7 +147,8 @@ export class Federated extends Server { const msg: messages.ReceiveServerPayload = { type: MessageTypes.ReceiveServerPayload, round, - payload: serialized + payload: serialized, + nbOfParticipants: aggregator.nodes.size } ws.send(msgpack.encode(msg)) }) @@ -192,45 +188,12 @@ export class Federated extends Server { const { payload, round } = msg const weights = serialization.weights.decode(payload) - this.createPromiseForWeights(task.id, aggregator, ws) + this.createPromiseForWeights(task.id, aggregator, ws) - // TODO support multiple communication round + // TODO support multiple communication round if (!aggregator.add(clientId, weights, round, 0)) { console.info(`dropped contribution from client ${clientId} for round ${round}`) return // TODO what to answer? - } - } else if (msg.type === MessageTypes.ReceiveServerMetadata) { - const { key, round } = msg - - const taskMetadata = this.metadataMap.get(task.id) - - if (!Number.isNaN(round) && round >= 0 && taskMetadata !== undefined) { - // Find the most recent entry round-wise for the given task (upper bounded - // by the given round). Allows for sporadic entries in the metadata map. - const latestRound = taskMetadata.keySeq().max() ?? round - - // Fetch the required metadata from the general metadata structure stored - // server-side and construct the queried metadata's map accordingly. This - // essentially creates a "ID -> metadata" single-layer map. - const queriedMetadataMap = Map( - taskMetadata - .get(latestRound, Map>()) - .filter((entries) => entries.has(key)) - .mapEntries(([id, entries]) => [id, entries.get(key)]) - ) - - this.logsAppend(task.id, clientId, MessageTypes.ReceiveServerMetadata, round) - - const msg: messages.ReceiveServerMetadata = { - type: MessageTypes.ReceiveServerMetadata, - taskId: task.id, - nodeId: clientId, - key, - round, - metadataMap: Array.from(queriedMetadataMap) - } - - ws.send(msgpack.encode(msg)) } } }) diff --git a/server/tests/client/federated.spec.ts b/server/tests/client/federated.spec.ts index 6ae3962d3..98fc8247e 100644 --- a/server/tests/client/federated.spec.ts +++ b/server/tests/client/federated.spec.ts @@ -26,7 +26,7 @@ describe("federated client", function () { const client = new clients.federated.FederatedClient( url, TASK, - new aggregators.MeanAggregator(), + aggregators.getAggregator(TASK), ); await client.connect(); await client.disconnect(); @@ -52,7 +52,7 @@ describe("federated client", function () { tensorBackend: 'tfjs' }, }, - new aggregators.MeanAggregator(), + aggregators.getAggregator(TASK), ); try { diff --git a/server/tests/e2e/decentralized.spec.ts b/server/tests/e2e/decentralized.spec.ts index 5fb45050a..c0ef7c5a6 100644 --- a/server/tests/e2e/decentralized.spec.ts +++ b/server/tests/e2e/decentralized.spec.ts @@ -2,7 +2,6 @@ import type { Server } from 'node:http' import { List } from 'immutable' import { assert } from 'chai' -import type { Task } from '@epfml/discojs' import { aggregator as aggregators, client as clients, WeightsContainer, defaultTasks, aggregation } from '@epfml/discojs' @@ -56,8 +55,6 @@ class MockSecureAggregator extends aggregators.SecureAggregator { } } -type MockAggregator = MockMeanAggregator | MockSecureAggregator - describe('end-to-end decentralized', function () { const epsilon = 1e-4 this.timeout(30_000) @@ -73,13 +70,15 @@ describe('end-to-end decentralized', function () { * the client will implement secure aggregation. If it is false, it will be a clear text client. */ async function simulateClient ( - Aggregator: new (task: Task) => MockAggregator, + aggregatorType: 'mean' | 'secure', input: number[], rounds: number ): Promise<[WeightsContainer, clients.Client]> { const task = defaultTasks.cifar10.getTask() const inputWeights = WeightsContainer.of(input) - const aggregator = new Aggregator(task) + const aggregator = aggregatorType == 'mean' ? + new MockMeanAggregator(undefined, undefined, 1, 'relative') + : new MockSecureAggregator() const client = new clients.decentralized.DecentralizedClient(url, task, aggregator) @@ -104,7 +103,7 @@ describe('end-to-end decentralized', function () { * The clients have model dimension of 4 model updates to share, which can be seen as their input parameter in makeClient. */ async function reachConsensus ( - Aggregator: new () => MockAggregator, + aggregatorType: 'mean' | 'secure', rounds = 1 ): Promise { // Expect the clients to reach the mean consensus, for both the mean and secure aggregators @@ -114,7 +113,7 @@ describe('end-to-end decentralized', function () { [0.002, 5, 30, 11], [0.003, 13, 11, 12] ) - const actual = await Promise.all(contributions.map(async (w) => await simulateClient(Aggregator, w, rounds)).toArray()) + const actual = await Promise.all(contributions.map(async (w) => await simulateClient(aggregatorType, w, rounds)).toArray()) const consensuses = await Promise.all(actual.map(async ([consensus, client]) => { // Disconnect clients once they reached consensus await client.disconnect() @@ -124,18 +123,18 @@ describe('end-to-end decentralized', function () { } it('single round of cifar 10 with three mean aggregators yields consensus', async () => { - await reachConsensus(MockMeanAggregator) + await reachConsensus('mean') }) it('several rounds of cifar 10 with three mean aggregators yields consensus', async () => { - await reachConsensus(MockMeanAggregator, 3) + await reachConsensus('mean', 3) }) it('single round of cifar 10 with three secure aggregators yields consensus', async () => { - await reachConsensus(MockSecureAggregator) + await reachConsensus('secure') }) it('several rounds of cifar 10 with three secure aggregators yields consensus', async () => { - await reachConsensus(MockSecureAggregator, 3) + await reachConsensus('secure', 3) }) }) diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 774343184..9bfe43801 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -38,13 +38,14 @@ describe("end-to-end federated", function () { const files = (await fs.readdir(dir)).map((file) => path.join(dir, file)) const labels = Repeat('cat', 24).toArray() // TODO read labels in csv + const trainingScheme = 'federated' const cifar10Task = defaultTasks.cifar10.getTask() - + cifar10Task.trainingInformation.scheme = trainingScheme const data = await new NodeImageLoader(cifar10Task).loadAll(files, { labels, shuffle: false }) - const aggregator = new aggregators.MeanAggregator() + const aggregator = aggregators.getAggregator(cifar10Task, {scheme: trainingScheme}) const client = new clients.federated.FederatedClient(url, cifar10Task, aggregator) - const disco = new Disco(cifar10Task, { scheme: 'federated', client }) + const disco = new Disco(cifar10Task, { scheme: trainingScheme, client }) await disco.trainFully(data); await disco.close() @@ -58,7 +59,9 @@ describe("end-to-end federated", function () { async function titanicUser (): Promise { const files = [DATASET_DIR + 'titanic_train.csv'] + const trainingScheme = 'federated' const titanicTask = defaultTasks.titanic.getTask() + titanicTask.trainingInformation.scheme = trainingScheme titanicTask.trainingInformation.epochs = titanicTask.trainingInformation.roundDuration = 5 const data = await (new NodeTabularLoader(titanicTask, ',').loadAll( files, @@ -68,10 +71,9 @@ describe("end-to-end federated", function () { shuffle: false } )) - - const aggregator = new aggregators.MeanAggregator() + const aggregator = aggregators.getAggregator(titanicTask, {scheme: trainingScheme}) const client = new clients.federated.FederatedClient(url, titanicTask, aggregator) - const disco = new Disco(titanicTask, { scheme: 'federated', client, aggregator }) + const disco = new Disco(titanicTask, { scheme: trainingScheme, client, aggregator }) const logs = List(await arrayFromAsync(disco.trainByRound(data))); await disco.close() @@ -88,8 +90,10 @@ describe("end-to-end federated", function () { return aggregator.model.weights } - async function wikitextUser (): Promise { + async function wikitextUser(): Promise { + const trainingScheme = 'federated' const task = defaultTasks.wikitext.getTask() + task.trainingInformation.scheme = trainingScheme task.trainingInformation.epochs = 2 const loader = new NodeTextLoader(task) const dataSplit: data.DataSplit = { @@ -97,9 +101,9 @@ describe("end-to-end federated", function () { validation: await data.TextData.init(await loader.load(DATASET_DIR + 'wikitext/wiki.valid.tokens'), task) } - const aggregator = new aggregators.MeanAggregator() + const aggregator = aggregators.getAggregator(task, {scheme: trainingScheme}) const client = new clients.federated.FederatedClient(url, task, aggregator) - const disco = new Disco(task, { scheme: 'federated', client, aggregator }) + const disco = new Disco(task, { scheme: trainingScheme, client, aggregator }) const logs = List(await arrayFromAsync(disco.trainByRound(dataSplit))); await disco.close() @@ -120,16 +124,19 @@ describe("end-to-end federated", function () { const positiveLabels = files[0].map(_ => 'COVID-Positive') const negativeLabels = files[1].map(_ => 'COVID-Negative') const labels = positiveLabels.concat(negativeLabels) + + const trainingScheme = 'federated' const lusCovidTask = defaultTasks.lusCovid.getTask() + lusCovidTask.trainingInformation.scheme = trainingScheme lusCovidTask.trainingInformation.epochs = 16 lusCovidTask.trainingInformation.roundDuration = 4 const data = await new NodeImageLoader(lusCovidTask) .loadAll(files.flat(), { labels, channels: 3 }) - const aggregator = new aggregators.MeanAggregator() + const aggregator = aggregators.getAggregator(lusCovidTask, {scheme: trainingScheme}) const client = new clients.federated.FederatedClient(url, lusCovidTask, aggregator) - const disco = new Disco(lusCovidTask, { scheme: 'federated', client }) + const disco = new Disco(lusCovidTask, { scheme: trainingScheme, client }) const logs = List(await arrayFromAsync(disco.trainByRound(data))); await disco.close() diff --git a/server/tests/validator.spec.ts b/server/tests/validator.spec.ts index a613ea206..e66db66c6 100644 --- a/server/tests/validator.spec.ts +++ b/server/tests/validator.spec.ts @@ -4,7 +4,7 @@ import type { Server } from 'node:http' import { Validator, ConsoleLogger, EmptyMemory, client as clients, - aggregator, defaultTasks, data + aggregator as aggregators, defaultTasks, data } from '@epfml/discojs' import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node' import { startServer } from '../src/index.js' @@ -30,17 +30,17 @@ describe('validator', function () { const adultLabels = files[1].map(_ => 'adult') const labels = childLabels.concat(adultLabels) - const simplefaceTask = defaultTasks.simpleFace.getTask() + const simpleFaceTask = defaultTasks.simpleFace.getTask() - const data = (await new NodeImageLoader(simplefaceTask) + const data = (await new NodeImageLoader(simpleFaceTask) .loadAll(files.flat(), { labels, channels: undefined })).train // Init a validator instance - const meanAggregator = new aggregator.MeanAggregator() - const client = new clients.Local(url, simplefaceTask, meanAggregator) + const meanAggregator = aggregators.getAggregator(simpleFaceTask, {scheme: 'local'}) + const client = new clients.Local(url, simpleFaceTask, meanAggregator) meanAggregator.setModel(await client.getLatestModel()) const validator = new Validator( - simplefaceTask, + simpleFaceTask, new ConsoleLogger(), new EmptyMemory(), undefined, @@ -71,7 +71,7 @@ describe('validator', function () { labels: titanicTask.trainingInformation.outputColumns, shuffle: false })).train - const meanAggregator = new aggregator.MeanAggregator() + const meanAggregator = aggregators.getAggregator(titanicTask, {scheme: 'local'}) const client = new clients.Local(url, titanicTask, meanAggregator) meanAggregator.setModel(await client.getLatestModel()) const validator = new Validator(titanicTask, new ConsoleLogger(), new EmptyMemory(), undefined, client) @@ -109,7 +109,7 @@ describe('validator', function () { .loadAll(files.flat(), { labels, channels: 3 })).train // Initialize a validator instance - const meanAggregator = new aggregator.MeanAggregator() + const meanAggregator = aggregators.getAggregator(lusCovidTask, {scheme: 'local'}) const client = new clients.Local(url, lusCovidTask, meanAggregator) meanAggregator.setModel(await client.getLatestModel()) diff --git a/webapp/src/clients.ts b/webapp/src/clients.ts deleted file mode 100644 index e32cb5d45..000000000 --- a/webapp/src/clients.ts +++ /dev/null @@ -1,21 +0,0 @@ -import type { Task } from '@epfml/discojs' -import { client as clients, aggregator as aggregators } from '@epfml/discojs' - -import { CONFIG } from './config' - -export function getClient (trainingScheme: Required, task: Task): clients.Client { - const aggregator = aggregators.getAggregator(task) - - switch (trainingScheme) { - case 'decentralized': - return new clients.decentralized.DecentralizedClient(CONFIG.serverUrl, task, aggregator) - case 'federated': - return new clients.federated.FederatedClient(CONFIG.serverUrl, task, aggregator) - case 'local': - return new clients.Local(CONFIG.serverUrl, task, aggregator) - default: { - const _: never = trainingScheme - throw new Error('should never happen') - } - } -} diff --git a/webapp/src/components/data/dataset_input/FileSelection.vue b/webapp/src/components/data/dataset_input/FileSelection.vue index 6143a7a17..baf13d4b0 100644 --- a/webapp/src/components/data/dataset_input/FileSelection.vue +++ b/webapp/src/components/data/dataset_input/FileSelection.vue @@ -75,7 +75,10 @@ -
+
{ isTraining.value = true isTrainingAlone.value = !distributed - console.log(isTraining.value, isTrainingAlone.value) // Reset training information before starting a new training trainingGenerator.value = undefined roundsLogs.value = List() @@ -162,7 +161,8 @@ async function startTraining(distributed: boolean): Promise { toaster.info("Model training started"); const scheme = distributed ? props.task.trainingInformation.scheme : "local"; - const client = getClient(scheme, props.task) + const aggregator = aggregators.getAggregator(props.task, { scheme }) // overwrite the training information scheme + const client = clients.getClient(scheme, CONFIG.serverUrl, props.task, aggregator) const disco = new Disco(props.task, { logger: { diff --git a/webapp/src/store/tasks.ts b/webapp/src/store/tasks.ts index f5c23fe53..5af68f513 100644 --- a/webapp/src/store/tasks.ts +++ b/webapp/src/store/tasks.ts @@ -9,8 +9,6 @@ import { useToaster } from '@/composables/toaster' import { CONFIG } from '@/config' import { useTrainingStore } from './training' -const TASKS_TO_FILTER_OUT = ['simple_face', 'cifar10'] - export const useTasksStore = defineStore('tasks', () => { const trainingStore = useTrainingStore() @@ -25,6 +23,8 @@ export const useTasksStore = defineStore('tasks', () => { tasks.value = tasks.value.set(task.id, task) } + const TASKS_TO_FILTER_OUT = ['simple_face', 'cifar10'] + async function initTasks (): Promise { try { const tasks = (await fetchTasks(CONFIG.serverUrl)).filter((t: Task) => !TASKS_TO_FILTER_OUT.includes(t.id))