Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix decentralized learning fail #708

Merged
merged 19 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,19 @@ async function runUser(
url: URL,
data: data.DataSplit,
): Promise<List<RoundLogs>> {
const client = new clients.federated.FederatedClient(
url,
task,
new aggregators.MeanAggregator(),
);

// force the federated scheme
const disco = new Disco(task, { scheme: "federated", client });
const trainingScheme = task.trainingInformation.scheme
const aggregator = aggregators.getAggregator(task)
const client = clients.getClient(trainingScheme, url, task, aggregator)
const disco = new Disco(task, { scheme: trainingScheme, client });

const logs = List(await arrayFromAsync(disco.trainByRound(data)));
await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish
await disco.close();
return logs;
}

async function main (task: Task, numberOfUsers: number): Promise<void> {
console.log(`Started federated training of ${task.id}`)
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })
const [server, url] = await startServer()

Expand Down
2 changes: 1 addition & 1 deletion discojs/src/aggregator.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { NodeID } from "./client/types.js";

const AGGREGATORS: Set<[name: string, new () => Aggregator]> = Set.of<
new (model?: Model) => Aggregator
>(MeanAggregator, SecureAggregator).map((Aggregator) => [
>(MeanAggregator, SecureAggregator).map((Aggregator) => [ // MeanAggregator waits for 100% of the node's contributions by default
Aggregator.name,
Aggregator,
]);
Expand Down
7 changes: 4 additions & 3 deletions discojs/src/aggregator/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ export abstract class Base<T> {
log (step: AggregationStep, from?: client.NodeID): void {
switch (step) {
case AggregationStep.ADD:
console.log(`> Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`)
console.log(`Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`)
break
case AggregationStep.UPDATE:
if (from === undefined) {
Expand All @@ -139,8 +139,8 @@ export abstract class Base<T> {
}

/**
* Sets the aggregator's TF.js model.
* @param model The new TF.js model
* Sets the aggregator's model.
* @param model The new model
*/
setModel (model: Model): void {
this._model = model
Expand All @@ -151,6 +151,7 @@ export abstract class Base<T> {
* peer/client within the network, whom we are communicating with during this aggregation
* round.
* @param nodeId The node to be added
* @returns True if the node wasn't already in the list of nodes, false if it was already included
*/
registerNode (nodeId: client.NodeID): boolean {
if (!this.nodes.has(nodeId)) {
Expand Down
70 changes: 48 additions & 22 deletions discojs/src/aggregator/get.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,59 @@
import type { Task } from '../index.js'
import { aggregator } from '../index.js'
import type { Model } from "../index.js";

/**
* Enumeration of the available types of aggregator.
*/
export enum AggregatorChoice {
MEAN,
SECURE,
BANDIT
}
type AggregatorOptions = Partial<{
model: Model,
scheme: Task['trainingInformation']['scheme'], // if undefined, fallback on task.trainingInformation.scheme
roundCutOff: number, // MeanAggregator
threshold: number, // MeanAggregator
thresholdType: 'relative' | 'absolute', // MeanAggregator
}>

/**
* Provides the aggregator object adequate to the given task.
* @param task The task
* Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
* Here is the ordered list of parameters used to define the aggregator and its default behavior:
* task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
*
* If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
* Otherwise, we default to a MeanAggregator for both training schemes.
*
* For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
* Unless specified otherwise, for federated learning or local training the aggregator defaults to waiting
* for a single contribution to trigger a model update
* (the server's model update for federated learning or our own contribution if training locally).
* For decentralized learning the aggregator defaults to waiting for every node's contribution to trigger a model update.
*
* @param task The task object associated with the current training session
* @param options Options passed down to the aggregator's constructor
* @returns The aggregator
*/
export function getAggregator (task: Task): aggregator.Aggregator {
const error = new Error('not implemented')
switch (task.trainingInformation.aggregator) {
case AggregatorChoice.MEAN:
return new aggregator.MeanAggregator()
case AggregatorChoice.BANDIT:
throw error
case AggregatorChoice.SECURE:
if (task.trainingInformation.scheme !== 'decentralized') {
export function getAggregator(task: Task, options: AggregatorOptions = {}): aggregator.Aggregator {
const aggregatorType = task.trainingInformation.aggregator ?? 'mean'
const scheme = options.scheme ?? task.trainingInformation.scheme

switch (aggregatorType) {
case 'mean':
if (scheme === 'decentralized') {
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
options = {
model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
...options
}
} else {
// If scheme == 'federated' then we only expect the server's contribution at each round
// so we set the aggregation threshold to 1 contribution
// If scheme == 'local' then we only expect our own contribution
options = {
model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
...options
}
}
return new aggregator.MeanAggregator(options.model, options.roundCutOff, options.threshold, options.thresholdType)
case 'secure':
if (scheme !== 'decentralized') {
throw new Error('secure aggregation is currently supported for decentralized only')
}
return new aggregator.SecureAggregator()
default:
return new aggregator.MeanAggregator()
return new aggregator.SecureAggregator(options.model, task.trainingInformation.maxShareValue)
}
}
2 changes: 1 addition & 1 deletion discojs/src/aggregator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ export { Base as AggregatorBase, AggregationStep } from './base.js'
export { MeanAggregator } from './mean.js'
export { SecureAggregator } from './secure.js'

export { getAggregator, AggregatorChoice } from './get.js'
export { getAggregator } from './get.js'

export type Aggregator = Base<WeightsContainer>
81 changes: 66 additions & 15 deletions discojs/src/aggregator/mean.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,81 @@ import { AggregationStep, Base as Aggregator } from "./base.js";
import type { Model, WeightsContainer, client } from "../index.js";
import { aggregation } from "../index.js";

/** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
type ThresholdType = 'relative' | 'absolute'

/**
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
*
*/
export class MeanAggregator extends Aggregator<WeightsContainer> {
readonly #threshold: number;
readonly #thresholdType: ThresholdType;

/**
* @param threshold - how many contributions for trigger an aggregation step.
* - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
* - absolute: t > 1, thus requiring t contributions
* Create a mean aggregator that averages all weight updates received when a specified threshold is met.
* By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
* only accepts contributions from the current round (drops contributions from previous rounds).
*
* @param threshold - how many contributions trigger an aggregation step.
* It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions.
* Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`.
* It can be an absolute number: if t >= 1 (then t has to be an integer), the aggregator waits for t contributions.
* Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update),
* set `threshold = 1` and `thresholdType = 'absolute'`
* @param thresholdType 'relative' or 'absolute', defaults to 'relative'. It is only needed to disambiguate the case `threshold = 1`.
* If `threshold != 1`, the type is inferred from the value ('relative' if below 1, 'absolute' if above 1) and an error is thrown if the specified thresholdType conflicts with it.
* If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution,
* if `thresholdType = 'relative'` then `threshold = 1` means 100% of this.nodes' contributions.
* @param roundCutoff - from how many past rounds do we still accept contributions.
* If 0 then only accept contributions from the current round,
* if 1 then the current round and the previous one, etc.
*/
// TODO no way to require a single contribution
constructor(model?: Model, roundCutoff = 0, threshold = 1) {
if (threshold <= 0) throw new Error("threshold must be striclty positive");
if (threshold > 1 && !Number.isInteger(threshold))
throw new Error("absolute thresholds must be integeral");

constructor(model?: Model, roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType) {
if (threshold <= 0) throw new Error("threshold must be strictly positive");
if (threshold > 1 && (!Number.isInteger(threshold)))
throw new Error("absolute thresholds must be integral");

super(model, roundCutoff, 1);
this.#threshold = threshold;

if (threshold < 1) {
// Throw exception if threshold and thresholdType are conflicting
if (thresholdType === 'absolute') {
throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`)
}
this.#thresholdType = 'relative'
}
else if (threshold > 1) {
// Throw exception if threshold and thresholdType are conflicting
if (thresholdType === 'relative') {
throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`)
}
this.#thresholdType = 'absolute'
}
// remaining case: threshold == 1
else {
// Print a warning regarding the default behavior when thresholdType is not specified
if (thresholdType === undefined) {
console.warn(
"[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
"To instead wait for a single contribution, set thresholdType = 'absolute'"
)
this.#thresholdType = 'relative'
} else {
this.#thresholdType = thresholdType
}
}
}

/** Checks whether the contributions buffer is full. */
override isFull(): boolean {
const actualThreshold =
this.#threshold <= 1
const thresholdValue =
this.#thresholdType == 'relative'
? this.#threshold * this.nodes.size
: this.#threshold;

return (this.contributions.get(0)?.size ?? 0) >= actualThreshold;
return (this.contributions.get(0)?.size ?? 0) >= thresholdValue;
}

override add(
Expand All @@ -42,8 +90,11 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
if (currentContributions !== 0)
throw new Error("only a single communication round");

if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
return false;
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
if (!this.nodes.has(nodeId)) console.warn("Contribution rejected because node id is not registered")
if (!this.isWithinRoundCutoff(round)) console.warn(`Contribution rejected because round ${round} is not within round cutoff`)
return false;
}

this.log(
this.contributions.hasIn([0, nodeId])
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/aggregator/secure.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ describe("secret shares test", function () {
describe("secure aggregator", () => {
it("behaves as mean aggregator", async () => {
const secureNetwork = setupNetwork(SecureAggregator)
const meanNetwork = setupNetwork(MeanAggregator)
const meanNetwork = setupNetwork(MeanAggregator) // waits for 100% of the nodes' contributions by default

const meanResults = await communicate(
Map(
Expand Down
9 changes: 6 additions & 3 deletions discojs/src/client/base.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import axios from 'axios'
import type { Set } from 'immutable'

import type { Model, Task, WeightsContainer } from '../index.js'
import { serialization } from '../index.js'
Expand Down Expand Up @@ -85,8 +84,12 @@ export abstract class Base {
_round: number,
): Promise<void> {}

get nodes (): Set<NodeID> {
return this.aggregator.nodes
// Number of contributors to a collaborative session
// If decentralized, it should be the number of peers
// If federated, it should be the number of participants excluding the server
// If local it should be 1
get nbOfParticipants(): number {
return this.aggregator.nodes.size // overriden by the federated client
}

get ownId(): NodeID {
Expand Down
Loading