From 8a2e846cbc11d57593777396a517185b45134a20 Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Thu, 1 Feb 2024 17:03:41 +0000 Subject: [PATCH 1/9] Fix #615, titanic training having NaN loss --- .../preprocessing/tabular_preprocessing.ts | 21 +++++++++++- .../src/dataset/data/tabular_data.ts | 2 +- .../discojs-core/src/default_tasks/titanic.ts | 5 +-- docs/node_example/data.ts | 13 +++++++- docs/node_example/example.ts | 18 ++++------ docs/node_example/start_server.ts | 33 ------------------- docs/node_example/tsconfig.json | 2 +- server/src/router/federated/server.ts | 3 ++ 8 files changed, 46 insertions(+), 51 deletions(-) delete mode 100644 docs/node_example/start_server.ts diff --git a/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts b/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts index 00cbbe976..34e4ddd67 100644 --- a/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts +++ b/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts @@ -1,3 +1,4 @@ +import { Task, tf } from '../../..' import { List } from 'immutable' import { PreprocessingFunction } from './base' @@ -9,7 +10,25 @@ export enum TabularPreprocessing { Normalize } +interface TabularEntry extends tf.TensorContainerObject { + xs: number[] + ys: tf.Tensor1D | number | undefined +} + +const sanitize: PreprocessingFunction = { + type: TabularPreprocessing.Sanitize, + apply: (entry: tf.TensorContainer, task: Task): tf.TensorContainer => { + const { xs, ys } = entry as TabularEntry + return { + xs: xs.map(i => i === undefined ? 0 : i), + ys: ys + } + } +} + /** * Available tabular preprocessing functions. */ -export const AVAILABLE_PREPROCESSING = List() +export const AVAILABLE_PREPROCESSING = List([ + sanitize] +).sortBy((e) => e.type) diff --git a/discojs/discojs-core/src/dataset/data/tabular_data.ts b/discojs/discojs-core/src/dataset/data/tabular_data.ts index 5b2416f18..d205f8f5f 100644 --- a/discojs/discojs-core/src/dataset/data/tabular_data.ts +++ b/discojs/discojs-core/src/dataset/data/tabular_data.ts @@ -21,7 +21,7 @@ export class TabularData extends Data { try { await dataset.iterator() } catch (e) { - throw new Error('Data input format is not compatible with the chosen task') + throw new Error('Data input format is not compatible with the chosen task'+ e) } return new TabularData(dataset, task, size) diff --git a/discojs/discojs-core/src/default_tasks/titanic.ts b/discojs/discojs-core/src/default_tasks/titanic.ts index 5c7214266..c51be7490 100644 --- a/discojs/discojs-core/src/default_tasks/titanic.ts +++ b/discojs/discojs-core/src/default_tasks/titanic.ts @@ -1,4 +1,4 @@ -import { tf, Task, TaskProvider } from '..' +import { tf, Task, TaskProvider, data } from '..' export const titanic: TaskProvider = { getTask (): Task { @@ -49,7 +49,8 @@ export const titanic: TaskProvider = { roundDuration: 10, validationSplit: 0.2, batchSize: 30, - preprocessingFunctions: [], + preprocessingFunctions: [data.TabularPreprocessing.Sanitize], + learningRate: 0.001, modelCompileData: { optimizer: 'sgd', loss: 'binaryCrossentropy', diff --git a/docs/node_example/data.ts b/docs/node_example/data.ts index 7e4f6f171..27136720c 100644 --- a/docs/node_example/data.ts +++ b/docs/node_example/data.ts @@ -1,7 +1,7 @@ import fs from 'fs' import Rand from 'rand-seed' -import { node, data, Task } from '@epfml/discojs-node' +import { node, data, Task, defaultTasks } from '@epfml/discojs-node' const rand = new Rand('1234') @@ -45,3 +45,14 @@ export async function loadData (task: Task): Promise { return await new node.data.NodeImageLoader(task).loadAll(files, { labels: labels }) } + +export async function loadTitanicData (task:Task): Promise { + const files = ['../../example_training_data/titanic_train.csv'] + const titanicTask = defaultTasks.titanic.getTask() + return await new node.data.NodeTabularLoader(task, ',').loadAll(files, { + features: titanicTask.trainingInformation.inputColumns, + labels: titanicTask.trainingInformation.outputColumns, + shuffle: false + }) +} \ No newline at end of file diff --git a/docs/node_example/example.ts b/docs/node_example/example.ts index 22258a299..c38bee46d 100644 --- a/docs/node_example/example.ts +++ b/docs/node_example/example.ts @@ -1,7 +1,5 @@ import { data, Disco, fetchTasks, Task } from '@epfml/discojs-node' - -import { startServer } from './start_server' -import { loadData } from './data' +import { loadTitanicData } from './data' /** * Example of discojs API, we load data, build the appropriate loggers, the disco object @@ -18,24 +16,20 @@ async function runUser (url: URL, task: Task, dataset: data.DataSplit): Promise< async function main (): Promise { - const [server, serverUrl] = await startServer() + // Start a server instance before running this example + const serverUrl = new URL('http://localhost:8080/') + const tasks = await fetchTasks(serverUrl) // Choose your task to train - const task = tasks.get('simple_face') as Task + const task = tasks.get('titanic') as Task - const dataset = await loadData(task) + const dataset = await loadTitanicData(task) // Add more users to the list to simulate more clients await Promise.all([ - runUser(serverUrl, task, dataset), runUser(serverUrl, task, dataset) ]) - - await new Promise((resolve, reject) => { - server.once('close', resolve) - server.close(reject) - }) } main().catch(console.error) diff --git a/docs/node_example/start_server.ts b/docs/node_example/start_server.ts deleted file mode 100644 index f8bea0f63..000000000 --- a/docs/node_example/start_server.ts +++ /dev/null @@ -1,33 +0,0 @@ -import http from 'node:http' - -import { Disco } from '@epfml/disco-server' - -export async function startServer (): Promise<[http.Server, URL]> { - const disco = new Disco() - await disco.addDefaultTasks() - - const server = disco.serve(8000) - await new Promise((resolve, reject) => { - server.once('listening', resolve) - server.once('error', reject) - server.on('error', console.error) - }) - - let addr: string - const rawAddr = server.address() - if (rawAddr === null) { - throw new Error('unable to get server address') - } else if (typeof rawAddr === 'string') { - addr = rawAddr - } else if (typeof rawAddr === 'object') { - if (rawAddr.family === '4') { - addr = `${rawAddr.address}:${rawAddr.port}` - } else { - addr = `[${rawAddr.address}]:${rawAddr.port}` - } - } else { - throw new Error('unable to get address to server') - } - - return [server, new URL('', `http://${addr}`)] -} \ No newline at end of file diff --git a/docs/node_example/tsconfig.json b/docs/node_example/tsconfig.json index 1e8eb3ae8..74dbc77ec 100644 --- a/docs/node_example/tsconfig.json +++ b/docs/node_example/tsconfig.json @@ -14,7 +14,7 @@ "declaration": true, - "typeRoots": ["node_modules/@types", "discojs-core/types"] + "typeRoots": ["node_modules/@types", "../../discojs/discojs-core/types"] }, "include": ["*.ts"], "exclude": ["node_modules"] diff --git a/server/src/router/federated/server.ts b/server/src/router/federated/server.ts index b8a1fc77f..a5bbee6a3 100644 --- a/server/src/router/federated/server.ts +++ b/server/src/router/federated/server.ts @@ -99,6 +99,9 @@ export class Federated extends Server { // Store the result promise somewhere for the server to fetch from, so that it can await // the result on client request. this.results = this.results.set(aggregator.task.taskID, result) + // Set a minimum amount of time to wait in the current round to let clients ask for the latest weights + // This is relevant mostly when there are few clients and rounds are almost instantaneous. + await new Promise(resolve => setTimeout(resolve, 1000)) await result void this.storeAggregationResult(aggregator) } From 89b81a261ae891a07daef10a5662e40ca746548a Mon Sep 17 00:00:00 2001 From: Julien Vignoud <33122365+JulienVig@users.noreply.github.com> Date: Thu, 1 Feb 2024 17:46:51 +0000 Subject: [PATCH 2/9] Fixes #612 Increase test timeout --- server/tests/e2e/federated.spec.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 3cae23333..b4f50e258 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -11,7 +11,7 @@ import { getClient, startServer } from '../utils' const SCHEME = TrainingSchemes.FEDERATED describe('end-to-end federated', function () { - this.timeout(90_000) + this.timeout(120_000) let server: Server beforeEach(async () => { From 3448e03042e0af3badad830bf754b69bf96fae66 Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Mon, 5 Feb 2024 12:13:45 +0100 Subject: [PATCH 3/9] Commented titanic test now passes --- .../src/validation/validator.spec.ts | 51 +++++++++++-------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/discojs/discojs-core/src/validation/validator.spec.ts b/discojs/discojs-core/src/validation/validator.spec.ts index 0e060a97c..0cb65d72a 100644 --- a/discojs/discojs-core/src/validation/validator.spec.ts +++ b/discojs/discojs-core/src/validation/validator.spec.ts @@ -1,7 +1,7 @@ import { assert } from 'chai' import fs from 'fs' -import { Task, node, Validator, ConsoleLogger, EmptyMemory, client as clients, data, aggregator } from '@epfml/discojs-node' +import { Task, node, Validator, ConsoleLogger, EmptyMemory, client as clients, data, aggregator, defaultTasks } from '@epfml/discojs-node' const simplefaceMock = { taskID: 'simple_face', @@ -57,23 +57,34 @@ describe('validator', () => { console.table(validator.confusionMatrix) }).timeout(10_000) - // TODO: fix titanic model (nan accuracy) - // it('works for titanic', async () => { - // const data: Data = await new NodeTabularLoader(titanic.task, ',') - // .loadAll(['../../example_training_data/titanic.csv'], { - // features: titanic.task.trainingInformation?.inputColumns, - // labels: titanic.task.trainingInformation?.outputColumns - // }) - // const validator = new Validator(titanic.task, new ConsoleLogger(), titanic.model()) - // await validator.assess(data) - - // assert( - // validator.visitedSamples() === data.size, - // `expected ${TITANIC_SAMPLES} visited samples but got ${validator.visitedSamples()}` - // ) - // assert( - // validator.accuracy() > 0.5, - // `expected accuracy greater than 0.5 but got ${validator.accuracy()}` - // ) - // }) + it('works for titanic', async () => { + const titanicTask = defaultTasks.titanic.getTask() + const files = ['../../example_training_data/titanic_train.csv'] + const data: data.Data = (await new node.data.NodeTabularLoader(titanicTask, ',').loadAll(files, { + features: titanicTask.trainingInformation.inputColumns, + labels: titanicTask.trainingInformation.outputColumns, + shuffle: false + })).train + const buffer = new aggregator.MeanAggregator(titanicTask) + const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, buffer) + buffer.setModel(await client.getLatestModel()) + const validator = new Validator(titanicTask, + new ConsoleLogger(), + new EmptyMemory(), + undefined, + client) + await validator.assess(data) + // data.size is undefined because tfjs handles dataset lazily + // instead we count the dataset size manually + let size = 0 + await data.dataset.forEachAsync(() => size+=1) + assert( + validator.visitedSamples === size, + `expected ${size} visited samples but got ${validator.visitedSamples}` + ) + assert( + validator.accuracy > 0.5, + `expected accuracy greater than 0.5 but got ${validator.accuracy}` + ) + }).timeout(10_000) }) From a1fe95fa40d5840ebcd7160bd8e05886203615fe Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Tue, 6 Feb 2024 12:17:43 +0100 Subject: [PATCH 4/9] Fix #610 collaborative training failing with timeout --- discojs/discojs-core/src/aggregator/base.ts | 4 +- .../discojs-core/src/client/federated/base.ts | 36 ++- .../src/dataset/data/tabular_data.ts | 2 +- docs/node_example/example.ts | 6 +- server/src/router/federated/server.ts | 226 +++++++++++------- 5 files changed, 155 insertions(+), 119 deletions(-) diff --git a/discojs/discojs-core/src/aggregator/base.ts b/discojs/discojs-core/src/aggregator/base.ts index 7dd2f94e6..ac802431e 100644 --- a/discojs/discojs-core/src/aggregator/base.ts +++ b/discojs/discojs-core/src/aggregator/base.ts @@ -67,7 +67,7 @@ export abstract class Base { */ protected readonly roundCutoff = 0, /** - * The number of communication rounds occuring during any given aggregation round. + * The number of communication rounds occurring during any given aggregation round. */ public readonly communicationRounds = 1 ) { @@ -272,7 +272,7 @@ export abstract class Base { } /** - * The current commnication round. + * The current communication round. */ get communicationRound (): number { return this._communicationRound diff --git a/discojs/discojs-core/src/client/federated/base.ts b/discojs/discojs-core/src/client/federated/base.ts index 43fd36acb..89aa42dee 100644 --- a/discojs/discojs-core/src/client/federated/base.ts +++ b/discojs/discojs-core/src/client/federated/base.ts @@ -92,41 +92,37 @@ export class Base extends Client { /** * Send a message containing our local weight updates to the federated server. + * And waits for the server to reply with the most recent aggregated weights * @param weights The weight updates to send */ - async sendPayload (payload: WeightsContainer): Promise { + private async sendPayloadAndReceiveResult (payload: WeightsContainer): Promise { const msg: messages.SendPayload = { type: type.SendPayload, payload: await serialization.weights.encode(payload), round: this.aggregator.round } this.server.send(msg) + // It is important than the client immediately awaits the server result or it may miss it + return await this.receiveResult() } /** - * Fetches the server's result for its current (most recent) round and add it to our aggregator. + * Waits for the server's result for its current (most recent) round and add it to our aggregator. * Updates the aggregator's round if it's behind the server's. */ - async receiveResult (): Promise { - this.serverRound = undefined - this.serverResult = undefined - - const msg: messages.MessageBase = { - type: type.ReceiveServerPayload - } - this.server.send(msg) - + async receiveResult (): Promise { try { const { payload, round } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload) - this.serverRound = round + const serverRound = round // Store the server result only if it is not stale if (this.aggregator.round <= round) { - this.serverResult = serialization.weights.decode(payload) + const serverResult = serialization.weights.decode(payload) // Update the local round to match the server's - if (this.aggregator.round < this.serverRound) { - this.aggregator.setRound(this.serverRound) + if (this.aggregator.round < serverRound) { + this.aggregator.setRound(serverRound) } + return serverResult } } catch (e) { console.error(e) @@ -226,13 +222,11 @@ export class Base extends Client { throw new Error('local aggregation result was not set') } - // Send our contribution to the server - await this.sendPayload(this.aggregator.makePayloads(weights).first()) - // Fetch the server result - await this.receiveResult() + // Send our local contribution to the server + // and receive the most recent weights as an answer to our contribution + const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first()) - // TODO @s314cy: add communication rounds to federated learning - if (this.serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, this.serverResult, round, 0)) { + if (serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) { // Regular case: the server sends us its aggregation result which will serve our // own aggregation result. } else { diff --git a/discojs/discojs-core/src/dataset/data/tabular_data.ts b/discojs/discojs-core/src/dataset/data/tabular_data.ts index d205f8f5f..804e4984b 100644 --- a/discojs/discojs-core/src/dataset/data/tabular_data.ts +++ b/discojs/discojs-core/src/dataset/data/tabular_data.ts @@ -21,7 +21,7 @@ export class TabularData extends Data { try { await dataset.iterator() } catch (e) { - throw new Error('Data input format is not compatible with the chosen task'+ e) + throw new Error(`Data input format is not compatible with the chosen task: ${e}`) } return new TabularData(dataset, task, size) diff --git a/docs/node_example/example.ts b/docs/node_example/example.ts index c38bee46d..244cbfeb1 100644 --- a/docs/node_example/example.ts +++ b/docs/node_example/example.ts @@ -16,7 +16,7 @@ async function runUser (url: URL, task: Task, dataset: data.DataSplit): Promise< async function main (): Promise { - // Start a server instance before running this example + // First have a server instance running before running this script const serverUrl = new URL('http://localhost:8080/') const tasks = await fetchTasks(serverUrl) @@ -28,7 +28,9 @@ async function main (): Promise { // Add more users to the list to simulate more clients await Promise.all([ - runUser(serverUrl, task, dataset) + runUser(serverUrl, task, dataset), + runUser(serverUrl, task, dataset), + runUser(serverUrl, task, dataset), ]) } diff --git a/server/src/router/federated/server.ts b/server/src/router/federated/server.ts index a5bbee6a3..1d40d783e 100644 --- a/server/src/router/federated/server.ts +++ b/server/src/router/federated/server.ts @@ -89,20 +89,24 @@ export class Federated extends Server { } /** - * Loop storing aggregation results, every time an aggregation result promise resolves. - * This happens once per round. + * + * Loop creating an aggregation result promise at each round. + * Because clients contribute to the round asynchronously, a promise is used to let them wait + * until the server has aggregated the weights. This loop creates a promise whenever the previous + * one resolved and awaits until it resolves. The promise is used in createPromiseForWeights. * @param aggregator The aggregation handler */ private async storeAggregationResult (aggregator: aggregators.Aggregator): Promise { - // Renew the aggregation result promise. + // Create a promise on the future aggregated weights const result = aggregator.receiveResult() - // Store the result promise somewhere for the server to fetch from, so that it can await - // the result on client request. + // Store the promise such that it is accessible from other methods this.results = this.results.set(aggregator.task.taskID, result) - // Set a minimum amount of time to wait in the current round to let clients ask for the latest weights - // This is relevant mostly when there are few clients and rounds are almost instantaneous. - await new Promise(resolve => setTimeout(resolve, 1000)) + // The promise resolves once the server received enough contributions (through the handle method) + // and the aggregator aggregated the weights. await result + //Update the server round with the aggregator round + this.rounds = this.rounds.set(aggregator.task.taskID, aggregator.round) + // Create a new promise for the next round void this.storeAggregationResult(aggregator) } @@ -116,6 +120,81 @@ export class Federated extends Server { void this.storeAggregationResult(aggregator) } + /** + * This method is called when a client sends its contribution to the server. The server + * first adds the contribution to the aggregator and then replies with the aggregated weights + * + * @param msg the client message received of type SendPayload which contains the local client's weights + * @param task the task for which the client is contributing + * @param clientId the clientID of the contribution + * @param ws the websocket through which send the aggregated weights + */ + private async addContributionAndSendModel(msg: messages.SendPayload, task: Task, + clientId: client.NodeID, ws: WebSocket) { + const { payload, round } = msg + const aggregator = this.aggregators.get(task.taskID) + + if (!(Array.isArray(payload) && + payload.every((e) => typeof e === 'number'))) { + throw new Error('received invalid weights format') + } + if (aggregator === undefined) { + throw new Error(`received weights for unknown task: ${task.taskID}`) + } + + // It is important to create a promise for the weights BEFORE adding the contribution + // Otherwise the server might go to the next round before sending the + // aggregated weights. Once the server has aggregated the weights it will + // send the new weights to the client. + this.createPromiseForWeights(task, aggregator, ws) + + const serialized = serialization.weights.decode(payload) + // Add the contribution to the aggregator, + // which returns False if the contribution is too old + if (!aggregator.add(clientId, serialized, round, 0)) { + console.info('Dropped contribution from client', clientId, 'for round', round) + } + } + + /** + * This method is called after received a local update. + * It puts the client on hold until the server has aggregated the weights + * by creating a Promise which will resolve once the server has received + * enough contributions. Relying on a promise is useful since clients may + * send their contributions at different times and a promise lets the server + * wait asynchronously for the results + * + * @param task the task to which the client is contributing + * @param aggregator the server aggregator, in order to access the current round + * @param ws the websocket through which send the aggregated weights + */ + private async createPromiseForWeights(task: Task, + aggregator: aggregators.Aggregator, ws: WebSocket){ + const promisedResult = this.results.get(task.taskID) + if (promisedResult === undefined) { + throw new Error(`result promise was not set for task ${task.taskID}`) + } + + // Wait for aggregation result to resolve with timeout, giving the network a time window + // to contribute to the model + void Promise.race([promisedResult, client.utils.timeout()]) + .then((result) => + // Reply with round - 1 because the round number should match the round at which the client sent its weights + // After the server aggregated the weights it also incremented the round so the server replies with round - 1 + [result, aggregator.round - 1] as [WeightsContainer, number]) + .then(async ([result, round]) => + [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) + .then(([serialized, round]) => { + const msg: messages.ReceiveServerPayload = { + type: MessageTypes.ReceiveServerPayload, + round, + payload: serialized + } + ws.send(msgpack.encode(msg)) + }) + .catch(console.error) + } + protected handle ( task: Task, ws: WebSocket, @@ -136,105 +215,66 @@ export class Federated extends Server { const msg = msgpack.decode(data) if (msg.type === MessageTypes.ClientConnected) { - let aggregator = this.aggregators.get(task.taskID) - if (aggregator === undefined) { - aggregator = new aggregators.MeanAggregator(task) - this.aggregators = this.aggregators.set(task.taskID, aggregator) - } - console.info('client', clientId, 'joined', task.taskID) + this.logsAppend(task.taskID, clientId, MessageTypes.ClientConnected, 0) - this.logsAppend(task.taskID, clientId, MessageTypes.ClientConnected, 0) + let aggregator = this.aggregators.get(task.taskID) + if (aggregator === undefined) { + aggregator = new aggregators.MeanAggregator(task) + this.aggregators = this.aggregators.set(task.taskID, aggregator) + } + console.info('client', clientId, 'joined', task.taskID) - const msg: AssignNodeID = { - type: MessageTypes.AssignNodeID, - id: clientId - } - ws.send(msgpack.encode(msg)) + const msg: AssignNodeID = { + type: MessageTypes.AssignNodeID, + id: clientId + } + ws.send(msgpack.encode(msg)) + } else if (msg.type === MessageTypes.SendPayload) { - const { payload, round } = msg - - const aggregator = this.aggregators.get(task.taskID) - - this.logsAppend( - task.taskID, - clientId, - MessageTypes.SendPayload, - msg.round - ) - - if (!( - Array.isArray(payload) && - payload.every((e) => typeof e === 'number') - )) { - throw new Error('received invalid weights format') - } + this.logsAppend(task.taskID, clientId, MessageTypes.SendPayload, msg.round) + + if (model === undefined) { + throw new Error('aggregator model was not set') + } + this.addContributionAndSendModel(msg, task, clientId, ws) - const serialized = serialization.weights.decode(payload) + } else if (msg.type === MessageTypes.ReceiveServerStatistics) { + const statistics = this.informants + .get(task.taskID) + ?.getAllStatistics() - if (aggregator === undefined) { - throw new Error(`received weights for unknown task: ${task.taskID}`) - } + const msg: messages.ReceiveServerStatistics = { + type: MessageTypes.ReceiveServerStatistics, + statistics: statistics ?? {} + } - // TODO @s314cy: add communication rounds to federated learning - if (!aggregator.add(clientId, serialized, round, 0)) { - console.info('Dropped contribution from client', clientId, 'for round', round) - } - } else if (msg.type === MessageTypes.ReceiveServerStatistics) { - const statistics = this.informants - .get(task.taskID) - ?.getAllStatistics() - - const msg: messages.ReceiveServerStatistics = { - type: MessageTypes.ReceiveServerStatistics, - statistics: statistics ?? {} - } + ws.send(msgpack.encode(msg)) - ws.send(msgpack.encode(msg)) } else if (msg.type === MessageTypes.ReceiveServerPayload) { - const aggregator = this.aggregators.get(task.taskID) - if (aggregator === undefined) { - throw new Error(`requesting round of unknown task: ${task.taskID}`) - } + this.logsAppend(task.taskID, clientId, MessageTypes.ReceiveServerPayload, 0) + const aggregator = this.aggregators.get(task.taskID) + if (aggregator === undefined) { + throw new Error(`requesting round of unknown task: ${task.taskID}`) + } + if (model === undefined) { + throw new Error('aggregator model was not set') + } - this.logsAppend(task.taskID, clientId, MessageTypes.ReceiveServerPayload, 0) - if (model === undefined) { - throw new Error('aggregator model was not set') - } - const promisedResult = this.results.get(task.taskID) - if (promisedResult === undefined) { - throw new Error(`result promise was not set for task ${task.taskID}`) - } - - // Wait for aggregation result with timeout, giving the network a time window - // to contribute to the model sent to the requesting client. - void Promise.race([promisedResult, client.utils.timeout()]) - .then((result) => - [result, aggregator.round - 1] as [WeightsContainer, number]) - .then(async ([result, round]) => - [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) - .then(([serialized, round]) => { - const msg: messages.ReceiveServerPayload = { - type: MessageTypes.ReceiveServerPayload, - round, - payload: serialized - } - ws.send(msgpack.encode(msg)) - }) - .catch(console.error) + this.createPromiseForWeights(task, aggregator, ws) } else if (msg.type === MessageTypes.SendMetadata) { - const { round, key, value } = msg + const { round, key, value } = msg - this.logsAppend(task.taskID, clientId, MessageTypes.SendMetadata, round) + this.logsAppend(task.taskID, clientId, MessageTypes.SendMetadata, round) - if (this.metadataMap.hasIn([task.taskID, round, clientId, key])) { - throw new Error('metadata already set') - } - this.metadataMap = this.metadataMap.setIn( - [task, round, clientId, key], - value - ) + if (this.metadataMap.hasIn([task.taskID, round, clientId, key])) { + throw new Error('metadata already set') + } + this.metadataMap = this.metadataMap.setIn( + [task, round, clientId, key], + value + ) } else if (msg.type === MessageTypes.ReceiveServerMetadata) { const key = msg.metadataId const round = Number.parseInt(msg.round, 0) From 7b50bf1e8368193be6201e9adc52312ae25666a8 Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Tue, 6 Feb 2024 13:59:24 +0100 Subject: [PATCH 5/9] Fix linting errors --- discojs/discojs-core/src/client/federated/base.ts | 8 -------- discojs/discojs-core/src/dataset/data/tabular_data.ts | 3 ++- discojs/discojs-core/src/validation/validator.spec.ts | 7 ++++--- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/discojs/discojs-core/src/client/federated/base.ts b/discojs/discojs-core/src/client/federated/base.ts index 89aa42dee..9710843b3 100644 --- a/discojs/discojs-core/src/client/federated/base.ts +++ b/discojs/discojs-core/src/client/federated/base.ts @@ -19,14 +19,6 @@ export class Base extends Client { * by this client class, the server is the only node which we are communicating with. */ public static readonly SERVER_NODE_ID = 'federated-server-node-id' - /** - * Most recent server-fetched round. - */ - private serverRound?: number - /** - * Most recent server-fetched aggregated result. - */ - private serverResult?: WeightsContainer /** * Statistics curated by the federated server. */ diff --git a/discojs/discojs-core/src/dataset/data/tabular_data.ts b/discojs/discojs-core/src/dataset/data/tabular_data.ts index 804e4984b..cb0e92a74 100644 --- a/discojs/discojs-core/src/dataset/data/tabular_data.ts +++ b/discojs/discojs-core/src/dataset/data/tabular_data.ts @@ -21,7 +21,8 @@ export class TabularData extends Data { try { await dataset.iterator() } catch (e) { - throw new Error(`Data input format is not compatible with the chosen task: ${e}`) + console.error('Data input format is not compatible with the chosen task.') + throw (e) } return new TabularData(dataset, task, size) diff --git a/discojs/discojs-core/src/validation/validator.spec.ts b/discojs/discojs-core/src/validation/validator.spec.ts index 0cb65d72a..b9cec0acc 100644 --- a/discojs/discojs-core/src/validation/validator.spec.ts +++ b/discojs/discojs-core/src/validation/validator.spec.ts @@ -1,7 +1,8 @@ import { assert } from 'chai' import fs from 'fs' -import { Task, node, Validator, ConsoleLogger, EmptyMemory, client as clients, data, aggregator, defaultTasks } from '@epfml/discojs-node' +import { Task, node, Validator, ConsoleLogger, EmptyMemory, + client as clients, data, aggregator, defaultTasks } from '@epfml/discojs-node' const simplefaceMock = { taskID: 'simple_face', @@ -55,7 +56,7 @@ describe('validator', () => { `expected accuracy greater than 0.3 but got ${validator.accuracy}` ) console.table(validator.confusionMatrix) - }).timeout(10_000) + }).timeout(15_000) it('works for titanic', async () => { const titanicTask = defaultTasks.titanic.getTask() @@ -86,5 +87,5 @@ describe('validator', () => { validator.accuracy > 0.5, `expected accuracy greater than 0.5 but got ${validator.accuracy}` ) - }).timeout(10_000) + }).timeout(15_000) }) From 85cf6bf8c32e79e714b7c16742c91966015372ce Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Tue, 6 Feb 2024 14:43:17 +0100 Subject: [PATCH 6/9] Fix linting errors --- server/src/router/federated/server.ts | 136 +++++++++++++------------- 1 file changed, 68 insertions(+), 68 deletions(-) diff --git a/server/src/router/federated/server.ts b/server/src/router/federated/server.ts index 1d40d783e..8237c0a90 100644 --- a/server/src/router/federated/server.ts +++ b/server/src/router/federated/server.ts @@ -89,9 +89,8 @@ export class Federated extends Server { } /** - * * Loop creating an aggregation result promise at each round. - * Because clients contribute to the round asynchronously, a promise is used to let them wait + * Because clients contribute to the round asynchronously, a promise is used to let them wait * until the server has aggregated the weights. This loop creates a promise whenever the previous * one resolved and awaits until it resolves. The promise is used in createPromiseForWeights. * @param aggregator The aggregation handler @@ -104,7 +103,7 @@ export class Federated extends Server { // The promise resolves once the server received enough contributions (through the handle method) // and the aggregator aggregated the weights. await result - //Update the server round with the aggregator round + // Update the server round with the aggregator round this.rounds = this.rounds.set(aggregator.task.taskID, aggregator.round) // Create a new promise for the next round void this.storeAggregationResult(aggregator) @@ -123,14 +122,14 @@ export class Federated extends Server { /** * This method is called when a client sends its contribution to the server. The server * first adds the contribution to the aggregator and then replies with the aggregated weights - * + * * @param msg the client message received of type SendPayload which contains the local client's weights * @param task the task for which the client is contributing * @param clientId the clientID of the contribution * @param ws the websocket through which send the aggregated weights */ - private async addContributionAndSendModel(msg: messages.SendPayload, task: Task, - clientId: client.NodeID, ws: WebSocket) { + private async addContributionAndSendModel (msg: messages.SendPayload, task: Task, + clientId: client.NodeID, ws: WebSocket): Promise { const { payload, round } = msg const aggregator = this.aggregators.get(task.taskID) @@ -141,15 +140,17 @@ export class Federated extends Server { if (aggregator === undefined) { throw new Error(`received weights for unknown task: ${task.taskID}`) } - + // It is important to create a promise for the weights BEFORE adding the contribution - // Otherwise the server might go to the next round before sending the + // Otherwise the server might go to the next round before sending the // aggregated weights. Once the server has aggregated the weights it will // send the new weights to the client. + // Use the void keyword to explicity avoid waiting for the promise to resolve this.createPromiseForWeights(task, aggregator, ws) - + .catch(console.error) + const serialized = serialization.weights.decode(payload) - // Add the contribution to the aggregator, + // Add the contribution to the aggregator, // which returns False if the contribution is too old if (!aggregator.add(clientId, serialized, round, 0)) { console.info('Dropped contribution from client', clientId, 'for round', round) @@ -159,17 +160,19 @@ export class Federated extends Server { /** * This method is called after received a local update. * It puts the client on hold until the server has aggregated the weights - * by creating a Promise which will resolve once the server has received - * enough contributions. Relying on a promise is useful since clients may - * send their contributions at different times and a promise lets the server + * by creating a Promise which will resolve once the server has received + * enough contributions. Relying on a promise is useful since clients may + * send their contributions at different times and a promise lets the server * wait asynchronously for the results - * + * * @param task the task to which the client is contributing * @param aggregator the server aggregator, in order to access the current round * @param ws the websocket through which send the aggregated weights */ - private async createPromiseForWeights(task: Task, - aggregator: aggregators.Aggregator, ws: WebSocket){ + private async createPromiseForWeights ( + task: Task, + aggregator: aggregators.Aggregator, + ws: WebSocket): Promise { const promisedResult = this.results.get(task.taskID) if (promisedResult === undefined) { throw new Error(`result promise was not set for task ${task.taskID}`) @@ -181,9 +184,9 @@ export class Federated extends Server { .then((result) => // Reply with round - 1 because the round number should match the round at which the client sent its weights // After the server aggregated the weights it also incremented the round so the server replies with round - 1 - [result, aggregator.round - 1] as [WeightsContainer, number]) + [result, aggregator.round - 1] as [WeightsContainer, number]) .then(async ([result, round]) => - [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) + [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) .then(([serialized, round]) => { const msg: messages.ReceiveServerPayload = { type: MessageTypes.ReceiveServerPayload, @@ -215,66 +218,63 @@ export class Federated extends Server { const msg = msgpack.decode(data) if (msg.type === MessageTypes.ClientConnected) { - this.logsAppend(task.taskID, clientId, MessageTypes.ClientConnected, 0) + this.logsAppend(task.taskID, clientId, MessageTypes.ClientConnected, 0) - let aggregator = this.aggregators.get(task.taskID) - if (aggregator === undefined) { - aggregator = new aggregators.MeanAggregator(task) - this.aggregators = this.aggregators.set(task.taskID, aggregator) - } - console.info('client', clientId, 'joined', task.taskID) + let aggregator = this.aggregators.get(task.taskID) + if (aggregator === undefined) { + aggregator = new aggregators.MeanAggregator(task) + this.aggregators = this.aggregators.set(task.taskID, aggregator) + } + console.info('client', clientId, 'joined', task.taskID) - const msg: AssignNodeID = { - type: MessageTypes.AssignNodeID, - id: clientId - } - ws.send(msgpack.encode(msg)) - + const msg: AssignNodeID = { + type: MessageTypes.AssignNodeID, + id: clientId + } + ws.send(msgpack.encode(msg)) } else if (msg.type === MessageTypes.SendPayload) { - this.logsAppend(task.taskID, clientId, MessageTypes.SendPayload, msg.round) - - if (model === undefined) { - throw new Error('aggregator model was not set') - } - this.addContributionAndSendModel(msg, task, clientId, ws) - - } else if (msg.type === MessageTypes.ReceiveServerStatistics) { - const statistics = this.informants - .get(task.taskID) - ?.getAllStatistics() - - const msg: messages.ReceiveServerStatistics = { - type: MessageTypes.ReceiveServerStatistics, - statistics: statistics ?? {} - } + this.logsAppend(task.taskID, clientId, MessageTypes.SendPayload, msg.round) - ws.send(msgpack.encode(msg)) + if (model === undefined) { + throw new Error('aggregator model was not set') + } + this.addContributionAndSendModel(msg, task, clientId, ws) + .catch(console.error) + } else if (msg.type === MessageTypes.ReceiveServerStatistics) { + const statistics = this.informants + .get(task.taskID) + ?.getAllStatistics() + + const msg: messages.ReceiveServerStatistics = { + type: MessageTypes.ReceiveServerStatistics, + statistics: statistics ?? {} + } + ws.send(msgpack.encode(msg)) } else if (msg.type === MessageTypes.ReceiveServerPayload) { - this.logsAppend(task.taskID, clientId, MessageTypes.ReceiveServerPayload, 0) - const aggregator = this.aggregators.get(task.taskID) - if (aggregator === undefined) { - throw new Error(`requesting round of unknown task: ${task.taskID}`) - } - if (model === undefined) { - throw new Error('aggregator model was not set') - } - - + this.logsAppend(task.taskID, clientId, MessageTypes.ReceiveServerPayload, 0) + const aggregator = this.aggregators.get(task.taskID) + if (aggregator === undefined) { + throw new Error(`requesting round of unknown task: ${task.taskID}`) + } + if (model === undefined) { + throw new Error('aggregator model was not set') + } - this.createPromiseForWeights(task, aggregator, ws) + this.createPromiseForWeights(task, aggregator, ws) + .catch(console.error) } else if (msg.type === MessageTypes.SendMetadata) { - const { round, key, value } = msg + const { round, key, value } = msg - this.logsAppend(task.taskID, clientId, MessageTypes.SendMetadata, round) + this.logsAppend(task.taskID, clientId, MessageTypes.SendMetadata, round) - if (this.metadataMap.hasIn([task.taskID, round, clientId, key])) { - throw new Error('metadata already set') - } - this.metadataMap = this.metadataMap.setIn( - [task, round, clientId, key], - value - ) + if (this.metadataMap.hasIn([task.taskID, round, clientId, key])) { + throw new Error('metadata already set') + } + this.metadataMap = this.metadataMap.setIn( + [task, round, clientId, key], + value + ) } else if (msg.type === MessageTypes.ReceiveServerMetadata) { const key = msg.metadataId const round = Number.parseInt(msg.round, 0) From a65205d3561ee768b555b7a56d0d3892056e31b8 Mon Sep 17 00:00:00 2001 From: Julien Vignoud <33122365+JulienVig@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:30:14 +0100 Subject: [PATCH 7/9] Wrap error into error message MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rousset --- discojs/discojs-core/src/dataset/data/tabular_data.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/discojs/discojs-core/src/dataset/data/tabular_data.ts b/discojs/discojs-core/src/dataset/data/tabular_data.ts index cb0e92a74..c15106cdc 100644 --- a/discojs/discojs-core/src/dataset/data/tabular_data.ts +++ b/discojs/discojs-core/src/dataset/data/tabular_data.ts @@ -21,8 +21,7 @@ export class TabularData extends Data { try { await dataset.iterator() } catch (e) { - console.error('Data input format is not compatible with the chosen task.') - throw (e) + throw new Error('Data input format is not compatible with the chosen task.', { cause: e }) } return new TabularData(dataset, task, size) From 2b880ae504bee29f003f57712dde3a6b960293b3 Mon Sep 17 00:00:00 2001 From: Julien Vignoud <33122365+JulienVig@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:34:03 +0100 Subject: [PATCH 8/9] Make internal method private MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rousset --- discojs/discojs-core/src/client/federated/base.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discojs/discojs-core/src/client/federated/base.ts b/discojs/discojs-core/src/client/federated/base.ts index 9710843b3..515cbfd0f 100644 --- a/discojs/discojs-core/src/client/federated/base.ts +++ b/discojs/discojs-core/src/client/federated/base.ts @@ -102,7 +102,7 @@ export class Base extends Client { * Waits for the server's result for its current (most recent) round and add it to our aggregator. * Updates the aggregator's round if it's behind the server's. */ - async receiveResult (): Promise { + private async receiveResult (): Promise { try { const { payload, round } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload) const serverRound = round From ac3b5d3eb44e0e49a1a6b2f8ef689eaabef35f20 Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Thu, 8 Feb 2024 12:48:10 +0100 Subject: [PATCH 9/9] Remove error cause because of incompatibility --- discojs/discojs-core/src/dataset/data/tabular_data.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/discojs/discojs-core/src/dataset/data/tabular_data.ts b/discojs/discojs-core/src/dataset/data/tabular_data.ts index c15106cdc..cb0e92a74 100644 --- a/discojs/discojs-core/src/dataset/data/tabular_data.ts +++ b/discojs/discojs-core/src/dataset/data/tabular_data.ts @@ -21,7 +21,8 @@ export class TabularData extends Data { try { await dataset.iterator() } catch (e) { - throw new Error('Data input format is not compatible with the chosen task.', { cause: e }) + console.error('Data input format is not compatible with the chosen task.') + throw (e) } return new TabularData(dataset, task, size)