Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing end-to-end test with Byzantine-Robust Aggregator for cifar10 task #557

Merged
merged 13 commits into from
May 4, 2023
Merged
14 changes: 7 additions & 7 deletions discojs/discojs-core/src/async_buffer.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,39 @@ const mockAggregateAndStoreWeights = async (_weights: Iterable<number>): Promise
describe('AsyncWeightBuffer tests', () => {
it('add weight update with old time stamp returns false', async () => {
const t0 = -1
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
assert.isFalse(await asyncWeightBuffer.add(id, weights[0], t0))
})
it('add weight update with recent time stamp returns true', async () => {
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
const t0 = Date.now()
assert.isTrue(await asyncWeightBuffer.add(id, weights[0], t0))
})
it('bufferIsFull returns false if it is not full', () => {
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
assert.isFalse(asyncWeightBuffer.bufferIsFull())
})
it('buffer adding with cutoff = 0', () => {
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
assert.isFalse(asyncWeightBuffer.isNotWithinRoundCutoff(0))
assert.isTrue(asyncWeightBuffer.isNotWithinRoundCutoff(-1))
})
it('buffer adding with different cutoff = 1', () => {
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 1)
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 1, () => {})
assert.isFalse(asyncWeightBuffer.isNotWithinRoundCutoff(0))
assert.isFalse(asyncWeightBuffer.isNotWithinRoundCutoff(-1))
assert.isTrue(asyncWeightBuffer.isNotWithinRoundCutoff(-2))
})
it('Adding enough updates to buffer launches aggregator and updates weights', async () => {
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
const t0 = Date.now()
await Promise.all(weights.map(async (w) => await asyncWeightBuffer.add(w.toString(), w, t0)))
expect(asyncWeightBuffer.buffer.size).equal(0)
expect(weights).eql(mockUpdatedWeights)
expect(asyncWeightBuffer.round).equal(1)
})
it('Testing two full cycles (adding x2 buffer capacity)', async () => {
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const asyncWeightBuffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
mockUpdatedWeights = []

const t0 = Date.now()
Expand Down
5 changes: 4 additions & 1 deletion discojs/discojs-core/src/async_buffer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ export class AsyncBuffer<T> {
public readonly taskID: TaskID,
private readonly bufferCapacity: number,
private readonly aggregateAndStoreWeights: (weights: Iterable<T>) => Promise<void>,
private readonly roundCutoff = 0
private readonly roundCutoff = 0,
private readonly onRoundEnd: (newRound: number) => void
) {
this.buffer = Map()
this.round = 0
Expand All @@ -47,6 +48,8 @@ export class AsyncBuffer<T> {
if (this.bufferIsFull()) {
await this.aggregateAndStoreWeights(this.buffer.values())

this.onRoundEnd(this.round)

this.round += 1
this.observer?.update()
this.buffer = Map()
Expand Down
8 changes: 4 additions & 4 deletions discojs/discojs-core/src/async_informant.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,27 @@ const mockAggregateAndStoreWeights = async (): Promise<void> => {}

describe('AsyncInformant tests', () => {
it('get correct round number', async () => {
const buffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const buffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
const informant = new AsyncInformant(buffer)
expect(informant.getCurrentRound()).to.eql(0)
await Promise.all(weights.map(async (w) => await buffer.add(w.toString(), w, Date.now())))
expect(informant.getCurrentRound()).to.eql(1)
})
it('get correct number of participants for last round', async () => {
const buffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const buffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
const informant = new AsyncInformant(buffer)
await Promise.all(weights.map(async (w) => await buffer.add(w.toString(), w, Date.now())))
expect(informant.getNumberOfParticipants()).to.eql(bufferCapacity)
})
it('get correct average number of participants', async () => {
const buffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const buffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
const informant = new AsyncInformant(buffer)
await Promise.all(weights.map(async (w) => await buffer.add(w.toString(), w, Date.now())))
await Promise.all(weights.map(async (w) => await buffer.add(w.toString(), w, Date.now())))
expect(informant.getAverageNumberOfParticipants()).to.eql(bufferCapacity)
})
it('get correct total number of participants', async () => {
const buffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights)
const buffer = new AsyncBuffer(taskId, bufferCapacity, mockAggregateAndStoreWeights, 0, () => {})
const informant = new AsyncInformant(buffer)
await Promise.all(weights.map(async (w) => await buffer.add(w.toString(), w, Date.now())))
await Promise.all(weights.map(async (w) => await buffer.add(w.toString(), w, Date.now())))
Expand Down
1 change: 1 addition & 0 deletions discojs/discojs-core/src/client/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ export abstract class Base {
abstract onRoundEndCommunication (
updatedWeights: WeightsContainer,
staleWeights: WeightsContainer,
updatedMomentum: WeightsContainer,
s314cy marked this conversation as resolved.
Show resolved Hide resolved
round: number,
trainingInformant: TrainingInformant
): Promise<WeightsContainer>
Expand Down
1 change: 1 addition & 0 deletions discojs/discojs-core/src/client/decentralized/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ export abstract class Base extends ClientBase {
async onRoundEndCommunication (
updatedWeights: WeightsContainer,
staleWeights: WeightsContainer,
_: WeightsContainer, // TODO: Implement Byzantine-Robust Aggregator in Decentralized setting
s314cy marked this conversation as resolved.
Show resolved Hide resolved
round: number,
trainingInformant: TrainingInformant
): Promise<WeightsContainer> {
Expand Down
40 changes: 30 additions & 10 deletions discojs/discojs-core/src/client/federated/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,20 @@ export class Client extends Base {
}

// It sends weights to the server
async postWeightsToServer (weights: WeightsContainer): Promise<void> {
const msg: messages.postWeightsToServer = {
type: type.postWeightsToServer,
weights: await serialization.weights.encode(weights),
// Sends the client's local update to the federated server.
// Either plain weights or a momentum payload (for the Byzantine-robust
// aggregator) may be provided; at least one of the two is required.
async postToServer (weights?: WeightsContainer, momentum?: WeightsContainer): Promise<void> {
  if (weights === undefined && momentum === undefined) {
    throw new Error('Invalid data to send to the server')
  }

  // Encode whichever payloads are present; absent ones stay undefined.
  const encodedWeights = weights !== undefined ? await serialization.weights.encode(weights) : undefined
  const encodedMomentum = momentum !== undefined ? await serialization.weights.encode(momentum) : undefined

  const msg: messages.postToServer = {
    type: type.postToServer,
    weights: encodedWeights,
    momentum: encodedMomentum,
    round: this.round
  }

  console.log(`${this.clientID} sending weights for round ${this.round}`)

  this.sendMessage(msg)
}

Expand All @@ -103,6 +111,8 @@ export class Client extends Base {

const received = await waitMessageWithTimeout(this.server, type.latestServerRound, MAX_WAIT_PER_ROUND)

console.log(`${this.clientID} received round ${received.round}`)

this.serverRound = received.round
this.serverWeights = serialization.weights.decode(received.weights)

Expand Down Expand Up @@ -181,15 +191,25 @@ export class Client extends Base {
async onRoundEndCommunication (
updatedWeights: WeightsContainer,
staleWeights: WeightsContainer,
updatedMomentum: WeightsContainer,
_: number,
trainingInformant: informant.FederatedInformant
): Promise<WeightsContainer> {
const noisyWeights = privacy.addDifferentialPrivacy(
updatedWeights,
staleWeights,
this.task
)
await this.postWeightsToServer(noisyWeights)
if (this.task.trainingInformation.byzantineRobustAggregator !== undefined && this.task.trainingInformation.tauPercentile !== undefined) {
console.log('Sending momentum to server')

await this.postToServer(undefined, updatedMomentum)
} else {
console.log('Sending weights to server')

const noisyWeights = privacy.addDifferentialPrivacy(
updatedWeights,
staleWeights,
this.task
)

await this.postToServer(noisyWeights, undefined)
}

await this.pullServerStatistics(trainingInformant)

Expand Down
11 changes: 6 additions & 5 deletions discojs/discojs-core/src/client/federated/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { weights } from '../../serialization'
import { type, hasMessageType } from '../messages'

export type MessageFederated =
postWeightsToServer |
postToServer |
latestServerRound |
pullServerStatistics |
postMetadata |
Expand All @@ -15,9 +15,10 @@ export type MessageFederated =
export interface messageGeneral {
type: type
}
export interface postWeightsToServer {
type: type.postWeightsToServer
weights: weights.Encoded
// Client-to-server message carrying a model update for a given round.
// The federated client sets exactly one of the optional payloads:
// `weights` for plain (differentially-private) updates, or `momentum`
// when the task uses the Byzantine-robust aggregator.
export interface postToServer {
type: type.postToServer
weights?: weights.Encoded
momentum?: weights.Encoded
round: number
}
export interface latestServerRound {
Expand Down Expand Up @@ -54,7 +55,7 @@ export function isMessageFederated (o: unknown): o is MessageFederated {
switch (o.type) {
case type.clientConnected:
return true
case type.postWeightsToServer:
case type.postToServer:
return true
case type.latestServerRound:
return true
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/client/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export enum type {
PartialSums,

// federated
postWeightsToServer,
postToServer,
postMetadata,
getMetadataMap,
latestServerRound,
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/client/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Time to wait for the others in milliseconds.
export const MAX_WAIT_PER_ROUND = 10_000
export const MAX_WAIT_PER_ROUND = 15_000

export async function timeout (ms: number): Promise<never> {
return await new Promise((resolve, reject) => {
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/dataset/data/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ export abstract class Data {

abstract batch (): Data

abstract preprocess (): Data
abstract preprocess (): Promise<Data>
}
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/dataset/data/image_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ export class ImageData extends Data {
return new ImageData(newDataset, this.task, this.size)
}

preprocess (): Data {
async preprocess (): Promise<Data> {
let newDataset = this.dataset
const preprocessImage = getPreprocessImage(this.task)
newDataset = newDataset.map((x: tf.TensorContainer) => preprocessImage(x))
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/dataset/data/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ export { DataSplit } from './data_split'
export { Data } from './data'
export { ImageData } from './image_data'
export { TabularData } from './tabular_data'
export { ImagePreprocessing } from './preprocessing'
export { ImagePreprocessing, TabularPreprocessing } from './preprocessing'
s314cy marked this conversation as resolved.
Show resolved Hide resolved
43 changes: 41 additions & 2 deletions discojs/discojs-core/src/dataset/data/preprocessing.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import { tf, Task } from '../..'
import { Dataset } from '../dataset'

type PreprocessImage = (image: tf.TensorContainer) => tf.TensorContainer
type PreprocessTabular = (dataset: Dataset) => Promise<Dataset>

export type Preprocessing = ImagePreprocessing
export type Preprocessing = ImagePreprocessing | TabularPreprocessing

export interface TabularTensorContainer extends tf.TensorContainerObject {
xs: number[]
ys: tf.Tensor1D | number | undefined
}

export interface ImageTensorContainer extends tf.TensorContainerObject {
xs: tf.Tensor3D | tf.Tensor4D
Expand All @@ -11,7 +18,11 @@ export interface ImageTensorContainer extends tf.TensorContainerObject {

export enum ImagePreprocessing {
Normalize = 'normalize',
Resize = 'resize'
Resize = 'resize',
}

// Preprocessing steps applicable to tabular datasets; listed in a task's
// `preprocessingFunctions` to be applied by `getPreprocessTabular`.
export enum TabularPreprocessing {
Normalize = 'normalize',
}

export function getPreprocessImage (task: Task): PreprocessImage {
Expand All @@ -36,3 +47,31 @@ export function getPreprocessImage (task: Task): PreprocessImage {
}
return preprocessImage
}

/**
 * Builds the tabular preprocessing pipeline for the given task.
 *
 * The returned async transform always drops rows containing missing
 * values, and — when the task's `preprocessingFunctions` include
 * `TabularPreprocessing.Normalize` — standardizes each column to zero
 * mean and unit variance.
 *
 * @param task Task whose `trainingInformation` selects the steps to apply.
 * @returns Async function mapping a raw `Dataset` to a preprocessed one.
 */
export function getPreprocessTabular (task: Task): PreprocessTabular {
  const preprocessTabular: PreprocessTabular = async (dataset: Dataset): Promise<Dataset> => {
    // Drop rows with missing values. Check both null and undefined:
    // `el !== undefined` alone would let null cells through and
    // corrupt the statistics below.
    dataset = dataset.filter(row => (row as TabularTensorContainer).xs.every(el => el !== undefined && el !== null))

    const info = task.trainingInformation
    if (info.preprocessingFunctions?.includes(TabularPreprocessing.Normalize)) {
      // Materialize the dataset as a 2D tensor to compute per-column statistics.
      const datasetAsArray = await dataset.map(row => (row as TabularTensorContainer).xs).toArray()

      // tf.tidy disposes every intermediate tensor once the plain JS
      // arrays have been extracted, avoiding a memory leak.
      const { dataMean, dataStd } = tf.tidy(() => {
        const datasetTensor2D = tf.tensor2d(datasetAsArray)
        const mean = datasetTensor2D.mean(0)
        const std = datasetTensor2D.sub(mean).square().mean(0).sqrt()
        return {
          dataMean: mean.arraySync() as number[],
          dataStd: std.arraySync() as number[]
        }
      })

      // A constant column has std 0; dividing by it would produce NaN.
      // Use 1 instead so such columns are merely centered.
      const safeStd = dataStd.map(s => s === 0 ? 1 : s)

      // Standardize the dataset column-wise.
      return dataset.map(row => ({
        xs: (row as TabularTensorContainer).xs.map((v, i) => (v - dataMean[i]) / safeStd[i]),
        ys: (row as TabularTensorContainer).ys
      }))
    }

    return dataset
  }
  return preprocessTabular
}
8 changes: 6 additions & 2 deletions discojs/discojs-core/src/dataset/data/tabular_data.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Task } from '../..'
import { getPreprocessTabular } from './preprocessing'
import { Dataset } from '../dataset'
import { Data } from './data'

Expand Down Expand Up @@ -28,7 +29,10 @@ export class TabularData extends Data {
return new TabularData(newDataset, this.task, this.size)
}

preprocess (): Data {
return this
// Applies the task-defined tabular preprocessing pipeline and wraps the
// transformed dataset in a fresh TabularData instance of the same size.
async preprocess (): Promise<Data> {
  const applyPreprocessing = getPreprocessTabular(this.task)
  const processed = await applyPreprocessing(this.dataset)
  return new TabularData(processed, this.task, this.size)
}
}
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/dataset/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export { Dataset } from './dataset'
export { DatasetBuilder } from './dataset_builder'
export { DataSplit, Data, TabularData, ImageData, ImagePreprocessing } from './data'
export { DataSplit, Data, TabularData, ImageData, ImagePreprocessing, TabularPreprocessing } from './data'
export { ImageLoader, TabularLoader, DataLoader } from './data_loader'
8 changes: 4 additions & 4 deletions discojs/discojs-core/src/default_tasks/cifar10.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { tf, Task, TaskProvider } from '..'
import { tf, Task, data, TaskProvider } from '..'

export const cifar10: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -28,9 +28,9 @@ export const cifar10: TaskProvider = {
metrics: ['accuracy']
},
dataType: 'image',
IMAGE_H: 32,
IMAGE_W: 32,
preprocessingFunctions: [],
preprocessingFunctions: [data.ImagePreprocessing.Resize],
IMAGE_H: 224,
IMAGE_W: 224,
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
scheme: 'Decentralized',
noiseScale: undefined,
Expand Down
11 changes: 5 additions & 6 deletions discojs/discojs-core/src/default_tasks/titanic.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { tf, Task, TaskProvider } from '..'
import { tf, Task, TaskProvider, data } from '..'

export const titanic: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -47,17 +47,16 @@ export const titanic: TaskProvider = {
modelID: 'titanic-model',
epochs: 20,
roundDuration: 10,
validationSplit: 0,
validationSplit: 0.2,
batchSize: 30,
preprocessingFunctions: [],
preprocessingFunctions: [data.TabularPreprocessing.Normalize],
modelCompileData: {
optimizer: 'rmsprop',
optimizer: 'sgd',
loss: 'binaryCrossentropy',
metrics: ['accuracy']
},
dataType: 'tabular',
inputColumns: [
'PassengerId',
'Age',
'SibSp',
'Parch',
Expand All @@ -79,7 +78,7 @@ export const titanic: TaskProvider = {

model.add(
tf.layers.dense({
inputShape: [6],
inputShape: [5],
units: 124,
activation: 'relu',
kernelInitializer: 'leCunNormal'
Expand Down
Loading