From f4725da256862b56f4eedbc1f39be244a988be7a Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Tue, 13 Feb 2024 17:15:47 +0100 Subject: [PATCH 1/3] Titanic: lower random weight init accuracy expectations and assert actual training and validation accuracy in end-to-end test --- .../src/validation/validator.spec.ts | 20 ++++++++-------- server/tests/e2e/federated.spec.ts | 24 ++++++++++++------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/discojs/discojs-core/src/validation/validator.spec.ts b/discojs/discojs-core/src/validation/validator.spec.ts index b9cec0acc..722e80a92 100644 --- a/discojs/discojs-core/src/validation/validator.spec.ts +++ b/discojs/discojs-core/src/validation/validator.spec.ts @@ -23,7 +23,7 @@ const simplefaceMock = { } as unknown as Task describe('validator', () => { - it('works for simple_face', async () => { + it('simple_face validator', async () => { const dir = '../../example_training_data/simple_face/' const files: string[][] = ['child/', 'adult/'] .map((subdir: string) => fs.readdirSync(dir + subdir) @@ -49,16 +49,16 @@ describe('validator', () => { } assert( validator.visitedSamples === data.size, - `expected ${size} visited samples but got ${validator.visitedSamples}` + `Expected ${size} visited samples but got ${validator.visitedSamples}` ) assert( validator.accuracy > 0.3, - `expected accuracy greater than 0.3 but got ${validator.accuracy}` + `Expected random weight init accuracy greater than 0.3 but got ${validator.accuracy}` ) console.table(validator.confusionMatrix) }).timeout(15_000) - it('works for titanic', async () => { + it('titanic validator', async () => { const titanicTask = defaultTasks.titanic.getTask() const files = ['../../example_training_data/titanic_train.csv'] const data: data.Data = (await new node.data.NodeTabularLoader(titanicTask, ',').loadAll(files, { @@ -66,9 +66,9 @@ describe('validator', () => { labels: titanicTask.trainingInformation.outputColumns, shuffle: false })).train - const buffer = new aggregator.MeanAggregator(titanicTask) - const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, buffer) - buffer.setModel(await client.getLatestModel()) + const aggregator = new aggregator.MeanAggregator(titanicTask) + const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, aggregator) + aggregator.setModel(await client.getLatestModel()) const validator = new Validator(titanicTask, new ConsoleLogger(), new EmptyMemory(), @@ -81,11 +81,11 @@ describe('validator', () => { await data.dataset.forEachAsync(() => size+=1) assert( validator.visitedSamples === size, - `expected ${size} visited samples but got ${validator.visitedSamples}` + `Expected ${size} visited samples but got ${validator.visitedSamples}` ) assert( - validator.accuracy > 0.5, - `expected accuracy greater than 0.5 but got ${validator.accuracy}` + validator.accuracy > 0.3, + `Expected random weight init accuracy greater than 0.3 but got ${validator.accuracy}` ) }).timeout(15_000) }) diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index b4f50e258..2dc7607ea 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -4,7 +4,10 @@ import { Server } from 'node:http' import { Range } from 'immutable' import { assert } from 'chai' -import { WeightsContainer, node, Disco, TrainingSchemes, client as clients, aggregator as aggregators, defaultTasks } from '@epfml/discojs-node' +import { + WeightsContainer, node, Disco, TrainingSchemes, client as clients, + aggregator as aggregators, informant, defaultTasks +} from '@epfml/discojs-node' import { getClient, startServer } from '../utils' @@ -59,14 +62,23 @@ describe('end-to-end federated', function () { const aggregator = new aggregators.MeanAggregator(titanicTask) const client = await getClient(clients.federated.FederatedClient, server, titanicTask, aggregator) - const disco = new Disco(titanicTask, { scheme: SCHEME, client, aggregator }) + const trainingInformant = new informant.FederatedInformant(titanicTask, 10) + const disco = new Disco(titanicTask, { scheme: SCHEME, client, aggregator, informant: trainingInformant }) await disco.fit(data) await disco.close() - + if (aggregator.model === undefined) { throw new Error('model was not set') } + assert( + trainingInformant.trainingAccuracy() > 0.6, + `expected training accuracy greater than 0.6 but got ${trainingInformant.trainingAccuracy()}` + ) + assert( + trainingInformant.validationAccuracy() > 0.6, + `expected validation accuracy greater than 0.6 but got ${trainingInformant.validationAccuracy()}` + ) return WeightsContainer.from(aggregator.model) } @@ -77,10 +89,6 @@ describe('end-to-end federated', function () { it('two titanic users reach consensus', async () => { const [m1, m2] = await Promise.all([titanicUser(), titanicUser()]) - assert.isTrue( - m1.weights.some((x) => x.isNaN()) || - m2.weights.some((x) => x.isNaN()) || - m1.equals(m2) - ) + assert.isTrue(m1.equals(m2)) }) }) From dcc4067e42993b773bf684b0c402f3ffa63565fd Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Tue, 13 Feb 2024 17:18:45 +0100 Subject: [PATCH 2/3] Fix import name shadowing --- .../discojs-core/src/validation/validator.spec.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/discojs/discojs-core/src/validation/validator.spec.ts b/discojs/discojs-core/src/validation/validator.spec.ts index 722e80a92..1d90cef3c 100644 --- a/discojs/discojs-core/src/validation/validator.spec.ts +++ b/discojs/discojs-core/src/validation/validator.spec.ts @@ -32,9 +32,9 @@ describe('validator', () => { const data: data.Data = (await new node.data.NodeImageLoader(simplefaceMock) .loadAll(files.flat(), { labels })).train - const buffer = new aggregator.MeanAggregator(simplefaceMock) - const client = new clients.Local(new URL('http://localhost:8080'), simplefaceMock, buffer) - buffer.setModel(await client.getLatestModel()) + const meanAggregator = new aggregator.MeanAggregator(simplefaceMock) + const client = new clients.Local(new URL('http://localhost:8080'), simplefaceMock, meanAggregator) + meanAggregator.setModel(await client.getLatestModel()) const validator = new Validator( simplefaceMock, new ConsoleLogger(), @@ -66,9 +66,9 @@ describe('validator', () => { labels: titanicTask.trainingInformation.outputColumns, shuffle: false })).train - const aggregator = new aggregator.MeanAggregator(titanicTask) - const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, aggregator) - aggregator.setModel(await client.getLatestModel()) + const meanAggregator = new aggregator.MeanAggregator(titanicTask) + const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, meanAggregator) + meanAggregator.setModel(await client.getLatestModel()) const validator = new Validator(titanicTask, new ConsoleLogger(), new EmptyMemory(), From 724ae4e38074e0b2958cb0ae2dca538e58724943 Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Tue, 13 Feb 2024 17:25:51 +0100 Subject: [PATCH 3/3] Fix server lint error --- server/tests/e2e/federated.spec.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 2dc7607ea..acac688e2 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -67,7 +67,7 @@ describe('end-to-end federated', function () { await disco.fit(data) await disco.close() - + if (aggregator.model === undefined) { throw new Error('model was not set') }