Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix training related failures #616

Merged
merged 11 commits into from
Feb 8, 2024
4 changes: 2 additions & 2 deletions discojs/discojs-core/src/aggregator/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ export abstract class Base<T> {
*/
protected readonly roundCutoff = 0,
/**
* The number of communication rounds occuring during any given aggregation round.
* The number of communication rounds occurring during any given aggregation round.
*/
public readonly communicationRounds = 1
) {
Expand Down Expand Up @@ -272,7 +272,7 @@ export abstract class Base<T> {
}

/**
* The current commnication round.
* The current communication round.
*/
get communicationRound (): number {
return this._communicationRound
Expand Down
44 changes: 15 additions & 29 deletions discojs/discojs-core/src/client/federated/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,6 @@ export class Base extends Client {
* by this client class, the server is the only node which we are communicating with.
*/
public static readonly SERVER_NODE_ID = 'federated-server-node-id'
/**
* Most recent server-fetched round.
*/
private serverRound?: number
/**
* Most recent server-fetched aggregated result.
*/
private serverResult?: WeightsContainer
tharvik marked this conversation as resolved.
Show resolved Hide resolved
/**
* Statistics curated by the federated server.
*/
Expand Down Expand Up @@ -92,41 +84,37 @@ export class Base extends Client {

/**
* Send a message containing our local weight updates to the federated server.
* And waits for the server to reply with the most recent aggregated weights
* @param weights The weight updates to send
*/
async sendPayload (payload: WeightsContainer): Promise<void> {
private async sendPayloadAndReceiveResult (payload: WeightsContainer): Promise<WeightsContainer|undefined> {
const msg: messages.SendPayload = {
type: type.SendPayload,
payload: await serialization.weights.encode(payload),
round: this.aggregator.round
}
this.server.send(msg)
// It is important than the client immediately awaits the server result or it may miss it
return await this.receiveResult()
}

/**
* Fetches the server's result for its current (most recent) round and add it to our aggregator.
* Waits for the server's result for its current (most recent) round and add it to our aggregator.
* Updates the aggregator's round if it's behind the server's.
*/
async receiveResult (): Promise<void> {
this.serverRound = undefined
this.serverResult = undefined

const msg: messages.MessageBase = {
type: type.ReceiveServerPayload
}
this.server.send(msg)

async receiveResult (): Promise<WeightsContainer|undefined> {
JulienVig marked this conversation as resolved.
Show resolved Hide resolved
try {
const { payload, round } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload)
this.serverRound = round
const serverRound = round

// Store the server result only if it is not stale
if (this.aggregator.round <= round) {
this.serverResult = serialization.weights.decode(payload)
const serverResult = serialization.weights.decode(payload)
// Update the local round to match the server's
if (this.aggregator.round < this.serverRound) {
this.aggregator.setRound(this.serverRound)
if (this.aggregator.round < serverRound) {
this.aggregator.setRound(serverRound)
}
return serverResult
}
} catch (e) {
console.error(e)
Expand Down Expand Up @@ -226,13 +214,11 @@ export class Base extends Client {
throw new Error('local aggregation result was not set')
}

// Send our contribution to the server
await this.sendPayload(this.aggregator.makePayloads(weights).first())
// Fetch the server result
await this.receiveResult()
// Send our local contribution to the server
// and receive the most recent weights as an answer to our contribution
const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first())
tharvik marked this conversation as resolved.
Show resolved Hide resolved

// TODO @s314cy: add communication rounds to federated learning
if (this.serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, this.serverResult, round, 0)) {
if (serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
// Regular case: the server sends us its aggregation result which will serve our
// own aggregation result.
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Task, tf } from '../../..'
import { List } from 'immutable'
import { PreprocessingFunction } from './base'

Expand All @@ -9,7 +10,25 @@ export enum TabularPreprocessing {
Normalize
}

interface TabularEntry extends tf.TensorContainerObject {
JulienVig marked this conversation as resolved.
Show resolved Hide resolved
xs: number[]
ys: tf.Tensor1D | number | undefined
}

const sanitize: PreprocessingFunction = {
type: TabularPreprocessing.Sanitize,
apply: (entry: tf.TensorContainer, task: Task): tf.TensorContainer => {
const { xs, ys } = entry as TabularEntry
return {
xs: xs.map(i => i === undefined ? 0 : i),
ys: ys
}
}
}

/**
* Available tabular preprocessing functions.
*/
export const AVAILABLE_PREPROCESSING = List<PreprocessingFunction>()
export const AVAILABLE_PREPROCESSING = List([
sanitize]
).sortBy((e) => e.type)
3 changes: 2 additions & 1 deletion discojs/discojs-core/src/dataset/data/tabular_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ export class TabularData extends Data {
try {
await dataset.iterator()
} catch (e) {
throw new Error('Data input format is not compatible with the chosen task')
console.error('Data input format is not compatible with the chosen task.')
throw (e)
tharvik marked this conversation as resolved.
Show resolved Hide resolved
}

return new TabularData(dataset, task, size)
Expand Down
5 changes: 3 additions & 2 deletions discojs/discojs-core/src/default_tasks/titanic.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { tf, Task, TaskProvider } from '..'
import { tf, Task, TaskProvider, data } from '..'

export const titanic: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -49,7 +49,8 @@ export const titanic: TaskProvider = {
roundDuration: 10,
validationSplit: 0.2,
batchSize: 30,
preprocessingFunctions: [],
preprocessingFunctions: [data.TabularPreprocessing.Sanitize],
learningRate: 0.001,
modelCompileData: {
optimizer: 'sgd',
loss: 'binaryCrossentropy',
Expand Down
54 changes: 33 additions & 21 deletions discojs/discojs-core/src/validation/validator.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { assert } from 'chai'
import fs from 'fs'

import { Task, node, Validator, ConsoleLogger, EmptyMemory, client as clients, data, aggregator } from '@epfml/discojs-node'
import { Task, node, Validator, ConsoleLogger, EmptyMemory,
client as clients, data, aggregator, defaultTasks } from '@epfml/discojs-node'

const simplefaceMock = {
taskID: 'simple_face',
Expand Down Expand Up @@ -55,25 +56,36 @@ describe('validator', () => {
`expected accuracy greater than 0.3 but got ${validator.accuracy}`
)
console.table(validator.confusionMatrix)
}).timeout(10_000)
}).timeout(15_000)

// TODO: fix titanic model (nan accuracy)
// it('works for titanic', async () => {
// const data: Data = await new NodeTabularLoader(titanic.task, ',')
// .loadAll(['../../example_training_data/titanic.csv'], {
// features: titanic.task.trainingInformation?.inputColumns,
// labels: titanic.task.trainingInformation?.outputColumns
// })
// const validator = new Validator(titanic.task, new ConsoleLogger(), titanic.model())
// await validator.assess(data)

// assert(
// validator.visitedSamples() === data.size,
// `expected ${TITANIC_SAMPLES} visited samples but got ${validator.visitedSamples()}`
// )
// assert(
// validator.accuracy() > 0.5,
// `expected accuracy greater than 0.5 but got ${validator.accuracy()}`
// )
// })
it('works for titanic', async () => {
const titanicTask = defaultTasks.titanic.getTask()
const files = ['../../example_training_data/titanic_train.csv']
const data: data.Data = (await new node.data.NodeTabularLoader(titanicTask, ',').loadAll(files, {
features: titanicTask.trainingInformation.inputColumns,
labels: titanicTask.trainingInformation.outputColumns,
shuffle: false
})).train
const buffer = new aggregator.MeanAggregator(titanicTask)
const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, buffer)
buffer.setModel(await client.getLatestModel())
const validator = new Validator(titanicTask,
new ConsoleLogger(),
new EmptyMemory(),
undefined,
client)
await validator.assess(data)
// data.size is undefined because tfjs handles dataset lazily
// instead we count the dataset size manually
let size = 0
await data.dataset.forEachAsync(() => size+=1)
assert(
validator.visitedSamples === size,
`expected ${size} visited samples but got ${validator.visitedSamples}`
)
assert(
validator.accuracy > 0.5,
`expected accuracy greater than 0.5 but got ${validator.accuracy}`
)
}).timeout(15_000)
})
13 changes: 12 additions & 1 deletion docs/node_example/data.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import fs from 'fs'
import Rand from 'rand-seed'

import { node, data, Task } from '@epfml/discojs-node'
import { node, data, Task, defaultTasks } from '@epfml/discojs-node'

const rand = new Rand('1234')

Expand Down Expand Up @@ -45,3 +45,14 @@ export async function loadData (task: Task): Promise<data.DataSplit> {

return await new node.data.NodeImageLoader(task).loadAll(files, { labels: labels })
}

export async function loadTitanicData (task:Task): Promise<data.
DataSplit> {
const files = ['../../example_training_data/titanic_train.csv']
const titanicTask = defaultTasks.titanic.getTask()
return await new node.data.NodeTabularLoader(task, ',').loadAll(files, {
features: titanicTask.trainingInformation.inputColumns,
labels: titanicTask.trainingInformation.outputColumns,
shuffle: false
})
}
20 changes: 8 additions & 12 deletions docs/node_example/example.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import { data, Disco, fetchTasks, Task } from '@epfml/discojs-node'

import { startServer } from './start_server'
import { loadData } from './data'
import { loadTitanicData } from './data'

/**
* Example of discojs API, we load data, build the appropriate loggers, the disco object
Expand All @@ -18,24 +16,22 @@ async function runUser (url: URL, task: Task, dataset: data.DataSplit): Promise<

async function main (): Promise<void> {

const [server, serverUrl] = await startServer()
// First have a server instance running before running this script
const serverUrl = new URL('http://localhost:8080/')

const tasks = await fetchTasks(serverUrl)

// Choose your task to train
const task = tasks.get('simple_face') as Task
const task = tasks.get('titanic') as Task

const dataset = await loadData(task)
const dataset = await loadTitanicData(task)

// Add more users to the list to simulate more clients
await Promise.all([
runUser(serverUrl, task, dataset),
runUser(serverUrl, task, dataset)
runUser(serverUrl, task, dataset),
runUser(serverUrl, task, dataset),
])

await new Promise((resolve, reject) => {
server.once('close', resolve)
server.close(reject)
})
}

main().catch(console.error)
33 changes: 0 additions & 33 deletions docs/node_example/start_server.ts
tharvik marked this conversation as resolved.
Show resolved Hide resolved

This file was deleted.

2 changes: 1 addition & 1 deletion docs/node_example/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"declaration": true,

"typeRoots": ["node_modules/@types", "discojs-core/types"]
"typeRoots": ["node_modules/@types", "../../discojs/discojs-core/types"]
tharvik marked this conversation as resolved.
Show resolved Hide resolved
},
"include": ["*.ts"],
"exclude": ["node_modules"]
Expand Down
Loading
Loading