Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates dependencies #655

Merged
merged 19 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions .github/workflows/lint-test-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -306,20 +306,13 @@ jobs:
- run: npm ci
- run: npm --workspace=./discojs/discojs-{core,web} run build
- run: npm --workspace=./web-client run test:unit
- uses: cypress-io/github-action@v6
with:
working-directory: ./web-client
install: false
component: true
env:
VUE_APP_SERVER_URL: http://server
- uses: cypress-io/github-action@v6
with:
working-directory: ./web-client
install: false
start: npm start
env:
VUE_APP_SERVER_URL: http://server
VITE_SERVER_URL: http://server

test-cli:
needs: [build-lib-core, build-lib-node, build-server, download-datasets]
Expand Down
2 changes: 1 addition & 1 deletion cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

The CLI lets one use DISCO in standalone manner (i.e. without running a server or browser backend manually). The CLI allows to conveniently simulate multiple clients and log metrics such as the training and validation accuracy of each client. Integration of DISCO into other apps can follow the same principles (no browser needed). Currently, the CLI only support running federated tasks. Since the CLI relies on Node.js, it uses DISCO through `discojs-node`.

For example, the following command trains a model on CIFAR10, using 4 federated clients for 15 epochs with a round duration of 5 batches (see [DISCOJS.md](../docs/DISCOJS.md#rounds) for more information on rounds)
For example, the following command trains a model on CIFAR10, using 4 federated clients for 15 epochs with a round duration of 5 epochs (see [DISCOJS.md](../docs/DISCOJS.md#rounds) for more information on rounds)

> [!NOTE]
> Make sure you first ran `./get_training_data.sh` (in the root folder) to download training data.
Expand Down
32 changes: 21 additions & 11 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
import { Range } from 'immutable'
import { List, Range } from 'immutable'
import fs from 'node:fs/promises'

import type { TrainerLog, data, Task } from '@epfml/discojs-core'
import type { data, RoundLogs, Task } from '@epfml/discojs-core'
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs-core'
import { startServer } from '@epfml/disco-server'

import { saveLog } from './utils.js'
import { getTaskData } from './data.js'
import { args } from './args.js'

async function runUser (task: Task, url: URL, data: data.DataSplit): Promise<TrainerLog> {
const client = new clients.federated.FederatedClient(url, task, new aggregators.MeanAggregator())
async function runUser(
task: Task,
url: URL,
data: data.DataSplit,
): Promise<List<RoundLogs>> {
const client = new clients.federated.FederatedClient(
url,
task,
new aggregators.MeanAggregator(),
);

// force the federated scheme
const disco = new Disco(task, { scheme: 'federated', client })
const disco = new Disco(task, { scheme: "federated", client });

await disco.fit(data)
await disco.close()
return await disco.logs()
let logs = List<RoundLogs>();
for await (const round of disco.fit(data)) logs = logs.push(round);

await disco.close();
return logs;
}

async function main (task: Task, numberOfUsers: number): Promise<void> {
Expand All @@ -32,8 +42,8 @@ async function main (task: Task, numberOfUsers: number): Promise<void> {
)

if (args.save) {
const fileName = `${task.id}_${numberOfUsers}users.csv`
saveLog(logs, fileName)
const fileName = `${task.id}_${numberOfUsers}users.csv`;
await fs.writeFile(fileName, JSON.stringify(logs, null, 2));
}
console.log('Shutting down the server...')
await new Promise((resolve, reject) => {
Expand Down
8 changes: 0 additions & 8 deletions cli/src/utils.ts

This file was deleted.

18 changes: 7 additions & 11 deletions discojs/discojs-core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
"watch": "nodemon --ext ts --ignore dist --exec npm run",
"build": "tsc",
"lint": "npx eslint .",
"test": "mocha",
"docs": "typedoc ./src/index.ts --theme oxide"
"test": "mocha"
},
"repository": {
"type": "git",
Expand All @@ -21,27 +20,24 @@
"homepage": "https://github.com/epfml/disco#readme",
"dependencies": {
"@tensorflow/tfjs": "4",
"@types/msgpack-lite": "0.1",
"@xenova/transformers": "2",
"axios": "1",
"gpt3-tokenizer": "1",
"immutable": "4",
"isomorphic-wrtc": "1",
"isomorphic-ws": "4",
"isomorphic-ws": "5",
"msgpack-lite": "0.1",
"simple-peer": "9",
"tslib": "2",
"ws": "8"
},
"devDependencies": {
"@types/chai": "4",
"@types/mocha": "9",
"@types/mocha": "10",
"@types/msgpack-lite": "0.1",
"@types/simple-peer": "9",
"chai": "4",
"mocha": "9",
"chai": "5",
"mocha": "10",
"nodemon": "3",
"ts-node": "10",
"typedoc": "0.22",
"typedoc-theme-oxide": "0.1"
"ts-node": "10"
}
}
32 changes: 4 additions & 28 deletions discojs/discojs-core/src/client/base.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import axios from 'axios'
import type { Set } from 'immutable'

import type { Model, Task, TrainingInformant, WeightsContainer } from '../index.js'
import type { Model, Task, WeightsContainer } from '../index.js'
import { serialization } from '../index.js'
import type { NodeID } from './types.js'
import type { EventConnection } from './event_connection.js'
Expand Down Expand Up @@ -65,55 +65,31 @@ export abstract class Base {
return await serialization.model.decode(new Uint8Array(response.data))
}

/**
* Communication callback called once at the beginning of the training instance.
* @param _weights The initial model weights
* @param _trainingInformant The training informant
*/
async onTrainBeginCommunication (
_weights: WeightsContainer,
_trainingInformant: TrainingInformant
): Promise<void> {}

/**
* Communication callback called once at the end of the training instance.
* @param _weights The final model weights
* @param _trainingInformant The training informant
*/
async onTrainEndCommunication (
_weights: WeightsContainer,
_trainingInformant: TrainingInformant
): Promise<void> {}

/**
* Communication callback called at the beginning of every training round.
* @param _weights The most recent local weight updates
* @param _round The current training round
* @param _trainingInformant The training informant
*/
async onRoundBeginCommunication (
async onRoundBeginCommunication(
_weights: WeightsContainer,
_round: number,
_trainingInformant: TrainingInformant
): Promise<void> {}

/**
* Communication callback called the end of every training round.
* @param _weights The most recent local weight updates
* @param _round The current training round
* @param _trainingInformant The training informant
*/
async onRoundEndCommunication (
async onRoundEndCommunication(
_weights: WeightsContainer,
_round: number,
_trainingInformant: TrainingInformant
): Promise<void> {}

get nodes (): Set<NodeID> {
return this.aggregator.nodes
}

get ownId (): NodeID {
get ownId(): NodeID {
if (this._ownId === undefined) {
throw new Error('the node is not connected')
}
Expand Down
6 changes: 3 additions & 3 deletions discojs/discojs-core/src/client/decentralized/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ export class Base extends Client {
this.pool = new PeerPool(peerIdMsg.id)
}

disconnect (): Promise<void> {
async disconnect (): Promise<void> {
// Disconnect from peers
this.pool?.shutdown()
await this.pool?.shutdown()
this.pool = undefined

if (this.connections !== undefined) {
Expand All @@ -136,7 +136,7 @@ export class Base extends Client {
}

// Disconnect from server
this.server?.disconnect()
await this.server?.disconnect()
this._server = undefined
this._ownId = undefined

Expand Down
9 changes: 4 additions & 5 deletions discojs/discojs-core/src/client/decentralized/peer.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ describe('peer', function () {
let peer2: Peer

beforeEach(async () => {
peer1 = await Peer.init('1')
peer2 = await Peer.init('2', true)
peer1 = new Peer('1')
peer2 = new Peer('2', true)
const peers = Set.of(peer1, peer2)

peer1.on('signal', (signal) => { peer2.signal(signal) })
Expand All @@ -19,9 +19,8 @@ describe('peer', function () {
).toArray())
})

afterEach(() => {
peer1.destroy()
peer2.destroy()
afterEach(async () => {
await Promise.all([peer1.destroy(), peer2.destroy()])
})

it('can send and receives a message', async () => {
Expand Down
24 changes: 13 additions & 11 deletions discojs/discojs-core/src/client/decentralized/peer.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { List, Map, Range, Seq } from 'immutable'
import wrtc from 'isomorphic-wrtc'
import SimplePeer from 'simple-peer'

import type { NodeID } from '../types.js'
Expand Down Expand Up @@ -43,6 +44,8 @@ interface Events {
//
// see feross/simple-peer#393 for more info
export class Peer {
private readonly peer: SimplePeer.Instance

private bufferSize?: number

private sendCounter: MessageID = 0
Expand All @@ -53,16 +56,11 @@ export class Peer {
chunks: Map<ChunkID, Buffer>
}>()

private constructor (
constructor (
public readonly id: NodeID,
private readonly peer: SimplePeer.Instance
) {}

static async init (id: NodeID, initiator: boolean = false): Promise<Peer> {
return new Peer(
id,
new SimplePeer({ wrtc: (await import('isomorphic-wrtc')).default, initiator })
)
initiator: boolean = false
) {
this.peer = new SimplePeer({ wrtc, initiator })
}

send (msg: Buffer): void {
Expand Down Expand Up @@ -157,8 +155,12 @@ export class Peer {
)
}

destroy (): void {
this.peer.destroy()
async destroy (): Promise<void> {
return new Promise((resolve, reject) => {
this.peer.once('error', reject)
this.peer.once('close', resolve)
this.peer.destroy()
})
}

signal (signal: SignalData): void {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ describe('peer pool', function () {
))
})

afterEach(() => {
pools.forEach((p) => { p.shutdown() })
afterEach(async () => {
await Promise.all(pools.valueSeq().map((p) => p.shutdown()))
})

function mockServer (poolId: string): EventConnection {
Expand All @@ -36,7 +36,7 @@ describe('peer pool', function () {
},
on: (): void => {},
once: (): void => {},
disconnect: (): void => {}
disconnect: (): Promise<void> => Promise.resolve()
}
}

Expand Down
15 changes: 6 additions & 9 deletions discojs/discojs-core/src/client/decentralized/peer_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ export class PeerPool {
private readonly id: NodeID
) {}

shutdown (): void {
async shutdown (): Promise<void> {
console.info(`[${this.id}] shutdown their peers`)

this.peers.forEach((peer) => { peer.disconnect() })
await Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect()))
this.peers = Map()
}

Expand All @@ -43,12 +43,11 @@ export class PeerPool {

console.info(`[${this.id}] is connecting peers:`, peersToConnect.toJS())

const newPeers = Map(await Promise.all(
const newPeers = Map(
peersToConnect
.filter((id) => !this.peers.has(id))
.map(async (id) => [id, await Peer.init(id, id < this.id)] as [string, Peer])
.toArray()
))
.map((id) => [id, new Peer(id, id < this.id)] as [string, Peer])
)

console.info(`[${this.id}] asked to connect new peers:`, newPeers.keySeq().toJS())
const newPeersConnections = newPeers.map((peer) => new PeerConnection(this.id, peer, signallingServer))
Expand All @@ -58,9 +57,7 @@ export class PeerPool {

clientHandle(this.peers)

await Promise.all(
Array.from(newPeersConnections.values()).map(async (connection) => { await connection.connect() }))

await Promise.all(newPeersConnections.valueSeq().map((conn) => conn.connect()))
console.info(`[${this.id}] knowns connected peers:`, this.peers.keySeq().toJS())

return this.peers
Expand Down
14 changes: 9 additions & 5 deletions discojs/discojs-core/src/client/event_connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export interface EventConnection {
on: <K extends type>(type: K, handler: (event: NarrowMessage<K>) => void) => void
once: <K extends type>(type: K, handler: (event: NarrowMessage<K>) => void) => void
send: <T extends Message>(msg: T) => void
disconnect: () => void
disconnect: () => Promise<void>
}

export async function waitMessage<T extends type> (connection: EventConnection, type: T): Promise<NarrowMessage<T>> {
Expand Down Expand Up @@ -75,8 +75,8 @@ export class PeerConnection extends EventEmitter<{ [K in type]: NarrowMessage<K>
this.peer.send(msgpack.encode(msg))
}

disconnect (): void {
this.peer.destroy()
async disconnect (): Promise<void> {
await this.peer.destroy()
}
}

Expand Down Expand Up @@ -117,8 +117,12 @@ export class WebSocketServer extends EventEmitter<{ [K in type]: NarrowMessage<K
})
}

disconnect (): void {
this.socket.close()
disconnect (): Promise<void> {
return new Promise((resolve, reject) => {
this.socket.once('close', resolve)
this.socket.once('error', reject)
this.socket.close()
})
}

send (msg: Message): void {
Expand Down
Loading