Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LUS COVID demo #660

Merged
merged 31 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f17537b
Add lus covid task to CLI, image_loader tests, fix node image loader …
JulienVig Apr 23, 2024
a5d338c
Fix label name one hot encoding mistake
JulienVig Apr 23, 2024
7393a86
Fix linting errors
JulienVig Apr 23, 2024
06eaaff
Fix WebSocket is not a constructor error when training collaboratively
JulienVig Apr 24, 2024
62f21ad
Improve lus_covid model architecture and hyperparamters
JulienVig Apr 24, 2024
b593360
Fix simple face loader test
JulienVig Apr 24, 2024
0b395a6
Download sample LUS images in datasets
JulienVig Apr 24, 2024
3acf6ae
Fix populate
JulienVig Apr 24, 2024
3270c61
Fix image_loader script
JulienVig Apr 24, 2024
9f3e16d
Fix small errors
JulienVig Apr 26, 2024
2356be5
Fix linting errors
JulienVig Apr 26, 2024
8123d0c
Improve lus_covid preprocessing
JulienVig Apr 28, 2024
3ced3af
Add validator and image loader tests for lus_covid
JulienVig Apr 28, 2024
2f3a573
Add an end-to-end lus covid training test
JulienVig Apr 28, 2024
b88fd4d
Fix image loader test mistake
JulienVig Apr 28, 2024
9be63fd
Make lus_covid federated
JulienVig Apr 28, 2024
44dc872
discojs-core/lus_covid: bump epochs to 50
tharvik Apr 29, 2024
ad828db
datasets: add covid lungs
tharvik Apr 22, 2024
efe85a2
web-client/vite: support node's Buffer
tharvik Apr 25, 2024
cfe7e11
discojs-core/trainer: show #participants
tharvik Apr 25, 2024
47e3c97
discojs-core/serialization/weights: lint
tharvik Apr 26, 2024
bc63290
discojs-core/dataset/builder: fw config
tharvik Apr 26, 2024
05abfd2
web-client/Trainer: inline disco
tharvik Apr 27, 2024
8311f2b
server/federated: avoid aggregator hell
tharvik Apr 25, 2024
90882f2
web-client/Trainer: reset logs
tharvik Apr 29, 2024
5886901
discojs-core/lus_covid: reduce epochs to 10
tharvik Apr 29, 2024
f7a9208
Fix duplications and typos
JulienVig Apr 29, 2024
7bde8b3
Bump lus_covid epochs to 50
JulienVig Apr 29, 2024
bd6843c
Increase lus_covid train test timeout
JulienVig Apr 29, 2024
d0a6237
Decrease lus_covid test epochs to 15
JulienVig May 1, 2024
f9a7cee
Merge develop
JulienVig May 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# dependencies
/node_modules/

# disco models
**/models
# built
dist/

Expand Down
5 changes: 4 additions & 1 deletion cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'

const unsafeArgs = parse<BenchmarkUnsafeArguments>(
{
task: { type: String, alias: 't', description: 'Task: titanic, simple_face or cifar10', defaultValue: 'simple_face' },
task: { type: String, alias: 't', description: 'Task: titanic, simple_face, cifar10 or lus_covid', defaultValue: 'simple_face' },
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 1 },
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration', defaultValue: 10 },
Expand All @@ -42,6 +42,7 @@ let supportedTasks: Map<string, Task> = Map()
supportedTasks = supportedTasks.set(defaultTasks.simpleFace.getTask().id, defaultTasks.simpleFace.getTask())
supportedTasks = supportedTasks.set(defaultTasks.titanic.getTask().id, defaultTasks.titanic.getTask())
supportedTasks = supportedTasks.set(defaultTasks.cifar10.getTask().id, defaultTasks.cifar10.getTask())
supportedTasks = supportedTasks.set(defaultTasks.lusCovid.getTask().id, defaultTasks.lusCovid.getTask())

const task = supportedTasks.get(unsafeArgs.task)
if (task === undefined) {
Expand All @@ -56,6 +57,8 @@ if (task.trainingInformation !== undefined) {
// For DP
// TASK.trainingInformation.clippingRadius = 10000000
// TASK.trainingInformation.noiseScale = 0
} else {
throw new Error("Task training information is undefined")
}

export const args: BenchmarkArguments = { ...unsafeArgs, task }
1 change: 0 additions & 1 deletion cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ async function runUser(
async function main (task: Task, numberOfUsers: number): Promise<void> {
console.log(`Started federated training of ${task.id}`)
console.log({ args })

const [server, url] = await startServer()

const data = await getTaskData(task)
Expand Down
60 changes: 30 additions & 30 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,51 +1,49 @@
import { Range } from 'immutable'
import fs from 'node:fs'
import fs_promises from 'fs/promises'
import fs from 'node:fs/promises'
import path from 'node:path'

import type { Task, data } from '@epfml/discojs-core'
import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'

function filesFromFolder (dir: string, folder: string, fractionToKeep: number): string[] {
const f = fs.readdirSync(dir + folder)
return f.slice(0, Math.round(f.length * fractionToKeep)).map(file => dir + folder + '/' + file)
}

async function simplefaceData (task: Task): Promise<data.DataSplit> {
const dir = '../datasets/simple_face/'
const youngFolders = ['child']
const oldFolders = ['adult']

// const dir = '../../face_age/'
// const youngFolders = ['007', '008', '009', '010', '011', '012', '013', '014']
// const oldFolders = ['021', '022', '023', '024', '025', '026']

// TODO: we just keep x% of data for faster training, e.g., for each folder, we keep 0.1 fraction of images
const fractionToKeep = 1
const youngFiles = youngFolders.flatMap(folder => {
return filesFromFolder(dir, folder, fractionToKeep)
})

const oldFiles = oldFolders.flatMap(folder => {
return filesFromFolder(dir, folder, fractionToKeep)
})

const filesPerFolder = [youngFiles, oldFiles]
const youngFolder = dir + 'child/'
const adultFolder = dir + 'adult/'

const labels = filesPerFolder.flatMap((files, index) => Array<string>(files.length).fill(`${index}`))
const files = filesPerFolder.flat()
const youngFiles = (await fs.readdir(youngFolder)).map(file => path.join(youngFolder, file))
const adultFiles = (await fs.readdir(adultFolder)).map(file => path.join(adultFolder, file))
const images = youngFiles.concat(adultFiles)

return await new NodeImageLoader(task).loadAll(files, { labels })
const youngLabels = youngFiles.map(_ => 'child')
const oldLabels = adultFiles.map(_ => 'adult')
const labels = youngLabels.concat(oldLabels)
return await new NodeImageLoader(task).loadAll(images, { labels })
}

async function cifar10Data (cifar10: Task): Promise<data.DataSplit> {
const dir = '../datasets/CIFAR10/'
const files = (await fs_promises.readdir(dir)).map((file) => path.join(dir, file))
const files = (await fs.readdir(dir)).map((file) => path.join(dir, file))
const labels = Range(0, 24).map((label) => (label % 10).toString()).toArray()

return await new NodeImageLoader(cifar10).loadAll(files, { labels })
}

async function lusCovidData (lusCovid: Task): Promise<data.DataSplit> {
const dir = '../datasets/lus_covid/'
const covid_pos = dir + 'COVID+'
const covid_neg = dir + 'COVID-'
const files_pos = (await fs.readdir(covid_pos)).map(file => path.join(covid_pos, file))
const label_pos = Range(0, files_pos.length).map(_ => 'COVID-Positive')

const files_neg = (await fs.readdir(covid_neg)).map(file => path.join(covid_neg, file))
const label_neg = Range(0, files_neg.length).map(_ => 'COVID-Negative')

const files = files_pos.concat(files_neg)
const labels = label_pos.concat(label_neg).toArray()

const dataConfig = { labels, shuffle: true, validationSplit: 0.1, channels: 3 }
return await new NodeImageLoader(lusCovid).loadAll(files, dataConfig)
}

async function titanicData (titanic: Task): Promise<data.DataSplit> {
const dir = '../datasets/titanic_train.csv'

Expand All @@ -68,6 +66,8 @@ export async function getTaskData (task: Task): Promise<data.DataSplit> {
return await titanicData(task)
case 'cifar10':
return await cifar10Data(task)
case 'lus_covid':
return await lusCovidData(task)
case 'YOUR CUSTOM TASK HERE':
throw new Error('YOUR CUSTOM FUNCTION HERE')
default:
Expand Down
3 changes: 3 additions & 0 deletions datasets/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@

# wikitext
/wikitext/

# LUS Covid
/lus_covid/
8 changes: 8 additions & 0 deletions datasets/populate
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@ cd "$(dirname "$0")"
curl 'http://deai-313515.appspot.com.storage.googleapis.com/example_training_data.tar.gz' |
tar --extract --strip-components=1

# lungs ultrasound
curl 'https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly/download' > archive.zip
unzip -u archive.zip
rm -rf 'lungs ultrasound'
mv DeAI-testimages 'lus_covid'
rm archive.zip

# wikitext
mkdir -p wikitext
cd wikitext
curl 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz' |
tar --extract --gzip --strip-components=1
cd ..
10 changes: 5 additions & 5 deletions discojs/discojs-core/src/client/event_connection.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import isomorphic from 'isomorphic-ws'
import WebSocket from 'isomorphic-ws'
import type { Peer, SignalData } from './decentralized/peer.js'
import { type NodeID } from './types.js'
import msgpack from 'msgpack-lite'
Expand Down Expand Up @@ -82,19 +82,19 @@ export class PeerConnection extends EventEmitter<{ [K in type]: NarrowMessage<K>

export class WebSocketServer extends EventEmitter<{ [K in type]: NarrowMessage<K> }> implements EventConnection {
private constructor (
private readonly socket: isomorphic.WebSocket,
private readonly socket: WebSocket.WebSocket,
private readonly validateSent?: (msg: Message) => boolean
) { super() }

static async connect (url: URL,
validateReceived: (msg: unknown) => msg is Message,
validateSent: (msg: Message) => boolean): Promise<WebSocketServer> {
const ws = new isomorphic.WebSocket(url)
const ws = new WebSocket(url)
ws.binaryType = 'arraybuffer'

const server: WebSocketServer = new WebSocketServer(ws, validateSent)

ws.onmessage = (event: isomorphic.MessageEvent) => {
ws.onmessage = (event: WebSocket.MessageEvent) => {
if (!(event.data instanceof ArrayBuffer)) {
throw new Error('server did not send an ArrayBuffer')
}
Expand All @@ -110,7 +110,7 @@ export class WebSocketServer extends EventEmitter<{ [K in type]: NarrowMessage<K
}

return await new Promise((resolve, reject) => {
ws.onerror = (err: isomorphic.ErrorEvent) => {
ws.onerror = (err: WebSocket.ErrorEvent) => {
reject(new Error(`Server unreachable: ${err.message}`))
}
ws.onopen = () => { resolve(server) }
Expand Down
44 changes: 16 additions & 28 deletions discojs/discojs-core/src/dataset/data/image_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,25 @@ export class ImageData extends Data {
// cause an error during training, because of the lazy aspect of the dataset; we only
// verify the first sample.
if (task.trainingInformation.preprocessingFunctions?.includes(ImagePreprocessing.Resize) !== true) {
try {
const sample = (await dataset.take(1).toArray())[0]
// TODO: We suppose the presence of labels
// TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
if (!(typeof sample === 'object' && sample !== null)) {
throw new Error()
}
const sample = (await dataset.take(1).toArray())[0]
// TODO: We suppose the presence of labels
// TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
if (typeof sample !== 'object' || sample === null || sample === undefined) {
throw new Error("Image is undefined or is not an object")
}

let shape
if ('xs' in sample && 'ys' in sample) {
shape = (sample as { xs: tf.Tensor, ys: number[] }).xs.shape
} else {
shape = (sample as tf.Tensor3D).shape
}
if (!(
shape[0] === task.trainingInformation.IMAGE_W &&
shape[1] === task.trainingInformation.IMAGE_H
)) {
throw new Error()
}
} catch (e) {
let cause
if (e instanceof Error) {
cause = e
} else {
console.error("got invalid Error type", e)
}
throw new Error('Data input format is not compatible with the chosen task', { cause })
let shape
if ('xs' in sample && 'ys' in sample) {
shape = (sample as { xs: tf.Tensor, ys: number[] }).xs.shape
} else {
shape = (sample as tf.Tensor3D).shape
}
const {IMAGE_H, IMAGE_W} = task.trainingInformation
if (IMAGE_W !== undefined && IMAGE_H !== undefined &&
(shape[0] !== IMAGE_W || shape[1] !== IMAGE_H)) {
throw new Error(`Image doesn't have the dimensions specified in the task's training information. Expected ${IMAGE_H}x${IMAGE_W} but got ${shape[0]}x${shape[1]}.`)
}
}

return new ImageData(dataset, task, size)
}

Expand Down
10 changes: 9 additions & 1 deletion discojs/discojs-core/src/dataset/data_loader/data_loader.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import type { DataSplit, Dataset } from '../index.js'

export interface DataConfig { features?: string[], labels?: string[], shuffle?: boolean, validationSplit?: number, inference?: boolean }
export interface DataConfig {
features?: string[],
labels?: string[],
shuffle?: boolean,
validationSplit?: number,
inference?: boolean,
// Mostly used for reading lus_covid images with 3 channels (default is 1 and causes an error)
channels?:number
}

export abstract class DataLoader<Source> {
abstract load (source: Source, config: DataConfig): Promise<Dataset>
Expand Down
33 changes: 25 additions & 8 deletions discojs/discojs-core/src/dataset/data_loader/image_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ import { DataLoader } from '../data_loader/index.js'
* 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels.
*/
export abstract class ImageLoader<Source> extends DataLoader<Source> {
abstract readImageFrom (source: Source): Promise<tf.Tensor3D>
// We allow specifying the number of channels because the default number of channels
// differs between web and node for the same image
// E.g. lus covid images have 1 channel with fs.readFile but 3 when loaded with discojs-web
abstract readImageFrom (source: Source, channels?:number): Promise<tf.Tensor3D>

constructor (
private readonly task: Task
Expand All @@ -27,10 +30,10 @@ export abstract class ImageLoader<Source> extends DataLoader<Source> {
async load (image: Source, config?: DataConfig): Promise<Dataset> {
let tensorContainer: tf.TensorContainer
if (config?.labels === undefined) {
tensorContainer = await this.readImageFrom(image)
tensorContainer = await this.readImageFrom(image, config?.channels)
} else {
tensorContainer = {
xs: await this.readImageFrom(image),
xs: await this.readImageFrom(image, config?.channels),
ys: config.labels[0]
}
}
Expand All @@ -46,7 +49,7 @@ export abstract class ImageLoader<Source> extends DataLoader<Source> {

let index = 0
while (index < indices.length) {
const sample = await self.readImageFrom(images[indices[index]])
const sample = await self.readImageFrom(images[indices[index]], config?.channels)
const label = withLabels ? labels[indices[index]] : undefined
const value = withLabels ? { xs: sample, ys: label } : sample

Expand All @@ -66,12 +69,26 @@ export abstract class ImageLoader<Source> extends DataLoader<Source> {

const indices = Range(0, images.length).toArray()
if (config?.labels !== undefined) {
const numberOfClasses = this.task.trainingInformation?.LABEL_LIST?.length
if (numberOfClasses === undefined) {
throw new Error('wanted labels but none found in task')
const labelList = this.task.trainingInformation?.LABEL_LIST
if (labelList === undefined || !Array.isArray(labelList)) {
throw new Error('LABEL_LIST should be specified in the task training information')
}
const numberOfClasses = labelList.length
// Map label strings to integer
const label_to_int = new Map(labelList.map((label_name, idx) => [label_name, idx]))
if (label_to_int.size != numberOfClasses) {
throw new Error("Input labels aren't matching the task LABEL_LIST")
}

labels = config.labels.map(label_name => {
const label_int = label_to_int.get(label_name)
if (label_int === undefined) {
throw new Error(`Found input label ${label_name} not specified in task LABEL_LIST`)
}
return label_int
})

labels = tf.oneHot(tf.tensor1d(config.labels, 'int32'), numberOfClasses).arraySync() as number[]
labels = await tf.oneHot(tf.tensor1d(labels, 'int32'), numberOfClasses).array() as number[]
}
if (config?.shuffle === undefined || config?.shuffle) {
this.shuffle(indices)
Expand Down
10 changes: 5 additions & 5 deletions discojs/discojs-core/src/dataset/dataset_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export class DatasetBuilder<Source> {
/**
* Whether a dataset was already produced.
*/
// TODO useless, responsiblity on callers
// TODO useless, responsibility on callers
private _built: boolean

constructor (
Expand Down Expand Up @@ -84,12 +84,12 @@ export class DatasetBuilder<Source> {
}

private getLabels (): string[] {
// We need to duplicate the labels as we need one for each soure.
// We need to duplicate the labels as we need one for each source.
// Say for label A we have sources [img1, img2, img3], then we
// need labels [A, A, A].
let labels: string[][] = []
this.labelledSources.valueSeq().forEach((sources, index) => {
const sourcesLabels = Array.from({ length: sources.length }, (_) => index.toString())
this.labelledSources.forEach((sources, label) => {
const sourcesLabels = Array.from({ length: sources.length }, (_) => label)
labels = labels.concat(sourcesLabels)
})
return labels.flat()
Expand Down Expand Up @@ -128,7 +128,7 @@ export class DatasetBuilder<Source> {
shuffle: false
}
const sources = this.labelledSources.valueSeq().toArray().flat()
dataTuple = await this.dataLoader.loadAll(sources, defaultConfig)
dataTuple = await this.dataLoader.loadAll(sources, { ...defaultConfig, ...config })
}
// TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
this._built = true
Expand Down
Loading