Skip to content

Commit

Permalink
discojs-core/aggregator: rm Task dep
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Mar 13, 2024
1 parent 7e7f272 commit dbdb764
Show file tree
Hide file tree
Showing 19 changed files with 99 additions and 142 deletions.
23 changes: 9 additions & 14 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,8 @@ import { saveLog } from './utils'
import { getTaskData } from './data'
import { args } from './args'

const NUMBER_OF_USERS = args.numberOfUsers
const TASK = args.task

const infoText = `\nStarted federated training of ${TASK.id}`
console.log(infoText)

console.log({ args })

async function runUser (task: Task, url: URL, data: data.DataSplit): Promise<TrainerLog> {
const client = new clients.federated.FederatedClient(url, task, new aggregators.MeanAggregator(TASK))
const client = new clients.federated.FederatedClient(url, task, new aggregators.MeanAggregator())

// force the federated scheme
const scheme = TrainingSchemes.FEDERATED
Expand All @@ -28,17 +20,20 @@ async function runUser (task: Task, url: URL, data: data.DataSplit): Promise<Tra
return await disco.logs()
}

async function main (): Promise<void> {
async function main (task: Task, numberOfUsers: number): Promise<void> {
console.log(`Started federated training of ${task.id}`)
console.log({ args })

const [server, url] = await startServer()

const data = await getTaskData(TASK)
const data = await getTaskData(task)

const logs = await Promise.all(
Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, url, data)).toArray()
Range(0, numberOfUsers).map(async (_) => await runUser(task, url, data)).toArray()
)

if (args.save) {
const fileName = `${TASK.id}_${NUMBER_OF_USERS}users.csv`
const fileName = `${task.id}_${numberOfUsers}users.csv`
saveLog(logs, fileName)
}
console.log('Shutting down the server...')
Expand All @@ -48,4 +43,4 @@ async function main (): Promise<void> {
})
}

main().catch(console.error)
main(args.task, args.numberOfUsers).catch(console.error)
6 changes: 1 addition & 5 deletions discojs/discojs-core/src/aggregator/base.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Map, Set } from 'immutable'

import type { client, Model, Task, AsyncInformant } from '..'
import type { client, Model, AsyncInformant } from '..'

import { EventEmitter } from '../utils/event_emitter'

Expand Down Expand Up @@ -54,10 +54,6 @@ export abstract class Base<T> {
protected _communicationRound = 0

constructor (
/**
* The task for which the aggregator should be created.
*/
public readonly task: Task,
/**
* The Model whose weights are updated on aggregation.
*/
Expand Down
6 changes: 3 additions & 3 deletions discojs/discojs-core/src/aggregator/get.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export function getAggregator (task: Task): aggregator.Aggregator {
const error = new Error('not implemented')
switch (task.trainingInformation.aggregator) {
case AggregatorChoice.MEAN:
return new aggregator.MeanAggregator(task)
return new aggregator.MeanAggregator()
case AggregatorChoice.ROBUST:
throw error
case AggregatorChoice.BANDIT:
Expand All @@ -28,8 +28,8 @@ export function getAggregator (task: Task): aggregator.Aggregator {
if (task.trainingInformation.scheme !== 'decentralized') {
throw new Error('secure aggregation is currently supported for decentralized only')
}
return new aggregator.SecureAggregator(task)
return new aggregator.SecureAggregator()
default:
return new aggregator.MeanAggregator(task)
return new aggregator.MeanAggregator()
}
}
4 changes: 2 additions & 2 deletions discojs/discojs-core/src/aggregator/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { type WeightsContainer } from '../weights'
import { type Base } from './base'
import type { WeightsContainer } from '../weights'
import type { Base } from './base'

export { Base as AggregatorBase, AggregationStep } from './base'
export { MeanAggregator } from './mean'
Expand Down
20 changes: 9 additions & 11 deletions discojs/discojs-core/src/aggregator/mean.spec.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import { assert, expect } from 'chai'
import type { Map } from 'immutable'

import type { client, Model, Task } from '..'
import type { client, Model } from '..'
import { aggregator, defaultTasks } from '..'
import { AggregationStep } from './base'

const task = defaultTasks.titanic.getTask()
const model = defaultTasks.titanic.getModel()
const id = 'a'
const weights = [1, 2, 3]
Expand All @@ -14,12 +13,11 @@ const bufferCapacity = weights.length

export class MockMeanAggregator extends aggregator.AggregatorBase<number> {
constructor (
task: Task,
model: Model,
private readonly threshold: number,
roundCutoff = 0
) {
super(task, model, roundCutoff, 1)
super(model, roundCutoff, 1)
}

isFull (): boolean {
Expand Down Expand Up @@ -56,36 +54,36 @@ export class MockMeanAggregator extends aggregator.AggregatorBase<number> {
describe('mean aggregator tests', () => {
it('adding weight update with old time stamp returns false', async () => {
const t0 = -1
const aggregator = new MockMeanAggregator(task, await model, bufferCapacity)
const aggregator = new MockMeanAggregator(await model, bufferCapacity)
assert.isFalse(aggregator.add(id, weights[0], t0))
})

it('adding weight update with recent time stamp returns true', async () => {
const aggregator = new MockMeanAggregator(task, await model, bufferCapacity)
const aggregator = new MockMeanAggregator(await model, bufferCapacity)
const t0 = Date.now()
assert.isTrue(aggregator.add(id, weights[0], t0))
})

it('aggregator returns false if it is not full', async () => {
const aggregator = new MockMeanAggregator(task, await model, bufferCapacity)
const aggregator = new MockMeanAggregator(await model, bufferCapacity)
assert.isFalse(aggregator.isFull())
})

it('aggregator with standard cutoff = 0', async () => {
const aggregator = new MockMeanAggregator(task, await model, bufferCapacity)
const aggregator = new MockMeanAggregator(await model, bufferCapacity)
assert.isTrue(aggregator.isWithinRoundCutoff(0))
assert.isFalse(aggregator.isWithinRoundCutoff(-1))
})

it('aggregator with different cutoff = 1', async () => {
const aggregator = new MockMeanAggregator(task, await model, bufferCapacity, 1)
const aggregator = new MockMeanAggregator(await model, bufferCapacity, 1)
assert.isTrue(aggregator.isWithinRoundCutoff(0))
assert.isTrue(aggregator.isWithinRoundCutoff(-1))
assert.isFalse(aggregator.isWithinRoundCutoff(-2))
})

it('adding enough updates to buffer launches aggregator and updates weights', async () => {
const aggregator = new MockMeanAggregator(task, await model, bufferCapacity)
const aggregator = new MockMeanAggregator(await model, bufferCapacity)
const mockAggregatedWeights = 2

const result = aggregator.receiveResult()
Expand All @@ -98,7 +96,7 @@ describe('mean aggregator tests', () => {
})

it('testing two full cycles (adding x2 buffer capacity)', async () => {
const aggregator = new MockMeanAggregator(task, await model, bufferCapacity, 0)
const aggregator = new MockMeanAggregator(await model, bufferCapacity, 0)

let mockAggregatedWeights = 2
let result = aggregator.receiveResult()
Expand Down
5 changes: 2 additions & 3 deletions discojs/discojs-core/src/aggregator/mean.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Map } from 'immutable'

import { AggregationStep, Base as Aggregator } from './base'
import type { Model, Task, WeightsContainer, client } from '..'
import type { Model, WeightsContainer, client } from '..'
import { aggregation } from '..'

/**
Expand All @@ -16,12 +16,11 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
public readonly threshold: number

constructor (
task: Task,
model?: Model,
roundCutoff = 0,
threshold = 1
) {
super(task, model, roundCutoff, 1)
super(model, roundCutoff, 1)

// Default threshold is 100% of node participation
if (threshold === undefined) {
Expand Down
22 changes: 12 additions & 10 deletions discojs/discojs-core/src/aggregator/robust.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Base as Aggregator } from './base'
import { type client, type WeightsContainer } from '..'
import type { client, Model, WeightsContainer } from '..'

import { type Map } from 'immutable'
import type { Map } from 'immutable'

export type Momentum = WeightsContainer

Expand All @@ -11,6 +11,15 @@ export class RobustAggregator extends Aggregator<WeightsContainer> {
// TODO @s314y: move to task definition
private readonly beta = 1

constructor (
private readonly tauPercentile: number,
model?: Model,
roundCutoff?: number,
communicationRounds?: number
) {
super(model, roundCutoff, communicationRounds)
}

add (nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound: number): boolean {
if (this.isWithinRoundCutoff(round)) {
const stale = this.contributions.get(communicationRound)
Expand All @@ -27,14 +36,7 @@ export class RobustAggregator extends Aggregator<WeightsContainer> {
}

aggregate (): void {
if (this.task.trainingInformation.tauPercentile === undefined) {
throw new Error('task doesn\'t provide tau percentile')
}
// this.emit(aggregation.avgClippingWeights(
// this.contributions.values(),
// undefined as unknown as WeightsContainer,
// this.task.trainingInformation.tauPercentile
// ))
throw new Error('not implemented')
}

makePayloads (weights: WeightsContainer): Map<client.NodeID, WeightsContainer> {
Expand Down
5 changes: 2 additions & 3 deletions discojs/discojs-core/src/aggregator/secure.spec.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import { List, Set, Range } from 'immutable'
import { assert } from 'chai'

import { aggregator as aggregators, aggregation, WeightsContainer, defaultTasks } from '@epfml/discojs-core'
import { aggregator as aggregators, aggregation, WeightsContainer } from '@epfml/discojs-core'

describe('secret shares test', function () {
const epsilon = 1e-4
const task = defaultTasks.cifar10.getTask()

const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10])
const secrets = List.of(
Expand All @@ -17,7 +16,7 @@ describe('secret shares test', function () {
function buildShares (): List<List<WeightsContainer>> {
const nodes = Set(secrets.keys()).map(String)
return secrets.map((secret) => {
const aggregator = new aggregators.SecureAggregator(task)
const aggregator = new aggregators.SecureAggregator()
aggregator.setNodes(nodes)
return aggregator.generateAllShares(secret)
})
Expand Down
12 changes: 4 additions & 8 deletions discojs/discojs-core/src/aggregator/secure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Map, List, Range } from 'immutable'
import tf from '@tensorflow/tfjs'

import { AggregationStep, Base as Aggregator } from './base'
import type { Model, Task, WeightsContainer, client } from '..'
import type { Model, WeightsContainer, client } from '..'
import { aggregation } from '..'

/**
Expand All @@ -16,15 +16,11 @@ import { aggregation } from '..'
export class SecureAggregator extends Aggregator<WeightsContainer> {
public static readonly MAX_SEED: number = 2 ** 47

private readonly maxShareValue: number

constructor (
task: Task,
model?: Model
model?: Model,
private readonly maxShareValue = 100
) {
super(task, model, 0, 2)

this.maxShareValue = this.task.trainingInformation.maxShareValue ?? 100
super(model, 0, 2)
}

aggregate (): void {
Expand Down
15 changes: 1 addition & 14 deletions discojs/discojs-core/src/async_informant.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type AggregatorBase } from './aggregator'
import type { AggregatorBase } from './aggregator'

export class AsyncInformant<T> {
private _round = 0
Expand All @@ -11,8 +11,6 @@ export class AsyncInformant<T> {
) {}

update (): void {
console.debug('before:')
this.printAllInfos()
if (this.round === 0 || this.round < this.aggregator.round) {
this._round = this.aggregator.round
this._currentNumberOfParticipants = this.aggregator.size
Expand All @@ -21,8 +19,6 @@ export class AsyncInformant<T> {
} else {
this._round = this.aggregator.round
}
console.debug('after:')
this.printAllInfos()
}

// Getter functions
Expand Down Expand Up @@ -52,13 +48,4 @@ export class AsyncInformant<T> {
averageNumberOfParticipants: this.averageNumberOfParticipants
}
}

// Debug
public printAllInfos (): void {
console.debug('task:', this.aggregator.task.id)
console.debug('round:', this.round)
console.debug('participants:', this.currentNumberOfParticipants)
console.debug('total:', this.totalNumberOfParticipants)
console.debug('average:', this.averageNumberOfParticipants)
}
}
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/training/disco.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export class Disco {
options.scheme = TrainingSchemes[task.trainingInformation.scheme as keyof typeof TrainingSchemes]
}
if (options.aggregator === undefined) {
options.aggregator = new MeanAggregator(task)
options.aggregator = new MeanAggregator()
}
if (options.client === undefined) {
if (options.url === undefined) {
Expand Down
4 changes: 2 additions & 2 deletions discojs/discojs-core/src/validation/validator.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ describe('validator', () => {

const data = (await new NodeImageLoader(simplefaceMock)
.loadAll(files.flat(), { labels })).train
const meanAggregator = new aggregator.MeanAggregator(simplefaceMock)
const meanAggregator = new aggregator.MeanAggregator()
const client = new clients.Local(new URL('http://localhost:8080'), simplefaceMock, meanAggregator)
meanAggregator.setModel(await client.getLatestModel())
const validator = new Validator(
Expand Down Expand Up @@ -66,7 +66,7 @@ describe('validator', () => {
labels: titanicTask.trainingInformation.outputColumns,
shuffle: false
})).train
const meanAggregator = new aggregator.MeanAggregator(titanicTask)
const meanAggregator = new aggregator.MeanAggregator()
const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, meanAggregator)
meanAggregator.setModel(await client.getLatestModel())
const validator = new Validator(titanicTask, new ConsoleLogger(), new EmptyMemory(), undefined, client)
Expand Down
Loading

0 comments on commit dbdb764

Please sign in to comment.