diff --git a/discojs/src/aggregator/aggregator.ts b/discojs/src/aggregator/aggregator.ts index f21aff2aa..dc9f03712 100644 --- a/discojs/src/aggregator/aggregator.ts +++ b/discojs/src/aggregator/aggregator.ts @@ -196,6 +196,19 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon this._nodes = this._nodes.delete(nodeId) } + /** + * Dispose the contributions to clean tensor memory + */ + dispose(): void { + this.contributions.forEach((roundContributions) => { + roundContributions.forEach((contribution) => { + contribution.dispose() + }) + }) + + this.contributions = Map() + } + /** * Overwrites the current set of active nodes with the given one. A node represents * an active neighbor peer/client within the network, whom we are communicating with diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9e1298b93..9999acc5f 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -14,6 +14,7 @@ import type { EventConnection } from './event_connection.js' import type { Aggregator } from '../aggregator/index.js' import { EventEmitter } from '../utils/event_emitter.js' import { type } from "./messages.js"; +import { ModelWeightAccess } from "../training/disco.js"; const debug = createDebug("discojs:client"); @@ -24,6 +25,7 @@ const debug = createDebug("discojs:client"); export abstract class Client extends EventEmitter<{ status: RoundStatus; participants: number; + modelSynced: WeightsContainer | undefined; }> { // Own ID provided by the network's server. protected _ownId?: NodeID @@ -38,6 +40,9 @@ export abstract class Client extends EventEmitter<{ */ protected promiseForMoreParticipants: Promise | undefined = undefined; + // Interface to access trainer's model weights + protected modelWeightAccess?: ModelWeightAccess; + /** * When the server notifies the client that they can resume training * after waiting for more participants, we want to be able to display what @@ -56,6 +61,15 @@ export abstract class Client extends EventEmitter<{ ) { super() } + + /** + * Used for decentralized learning. + * Set the interface used by client to access to trainer's model weights. + * Disco object provides this access. + */ + setModelWeightAccess(modelWeightAccess: ModelWeightAccess){ + this.modelWeightAccess = modelWeightAccess + } /** * Communication callback called at the beginning of every training round. @@ -193,6 +207,10 @@ export abstract class Client extends EventEmitter<{ return await serialization.model.decode(encoded) } + public finishRound(): void{ + // DecentralizedClient override the method to clean up round state + } + /** * Number of contributors to a collaborative session * If decentralized, it should be the number of peers diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 6f9da6e77..b616f99f8 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -5,7 +5,7 @@ import type { DataType, Model, WeightsContainer } from "../../index.js"; import { serialization } from "../../index.js"; import { Client, shortenId } from '../client.js' import { type NodeID } from '../index.js' -import { type, type ClientConnected } from '../messages.js' +import { type, type ClientConnected, NarrowMessage } from '../messages.js' import { timeout } from '../utils.js' import { WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js' import { PeerPool } from './peer_pool.js' @@ -26,6 +26,14 @@ export class DecentralizedClient extends Client<"decentralized"> { #pool?: PeerPool #connections?: Map + // Flag if this model requires model synchronization + #modelSyncNeeded?: boolean + + // Check if the training round is in progress + // Used to get the latest model for model synchronization + #roundFinishedPromise?: Promise + #resolveRoundFinished?: () => void // contains resolver + // Used to handle timeouts and promise resolving after calling disconnect private get isDisconnected() : boolean { return this._server === undefined @@ -36,6 +44,26 @@ export class DecentralizedClient extends Client<"decentralized"> { // Emits the `participants` event this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size } + + // Used by model provider peer during model syncing + private async handleSignalNewPeer(event: NarrowMessage): Promise { + if (this.#pool === undefined){ + throw new Error('received signal about new peer but peer pool is undefined') + } + const roundFinishedPromise = this.#roundFinishedPromise + const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) + + const newcomerConn = syncConnection.get(event.newNode) + + if (newcomerConn === undefined){ + // if connection with newly joining client fails, print debug message + // and return + debug(`Cannot connect to newly joined client [${event.newNode}]`) + return + } + + await this.sendModel(newcomerConn, roundFinishedPromise) + } /** * Public method called by disco.ts when starting training. This method sends @@ -69,6 +97,13 @@ export class DecentralizedClient extends Client<"decentralized"> { this.#pool.signal(event.peer, event.signal) }) + // Listen if the client is selected as a model provider node for a newly joining client. + // Upon receiving the signal, this client establishes a connection with the newcomer + // and sends the latest model weights. + this.server.on(type.SignalNewPeer, (event) => { + void this.handleSignalNewPeer(event) + }) + // c.f. setupServerCallbacks doc for explanation let receivedEnoughParticipants = false this.setupServerCallbacks(() => receivedEnoughParticipants = true) @@ -79,8 +114,9 @@ export class DecentralizedClient extends Client<"decentralized"> { this.server.send(msg) const { id, waitForMoreParticipants, - nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) - + nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) + + this.#modelSyncNeeded = joinedMidTraining this.nbOfParticipants = nbOfParticipants @@ -129,14 +165,50 @@ export class DecentralizedClient extends Client<"decentralized"> { * When connected, one peer creates a promise for every other peer's weight update * and waits for it to resolve. * + * If a client joined the training after the first round, + * model syncing happens first to get the latest model. */ override async onRoundBeginCommunication(): Promise { + if (this.#modelSyncNeeded) { + // 1. If model sync is needed, send server a request + this.server.send({ type: type.ModelSyncRequest }) + + // 2. Get the provider information from the server + const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider") + + if (this.#pool === undefined) { + throw new Error('peer pool is undefined, make sure to call `client.connect()` first') + } + + // 3. Connect with model provider client and get the latest model + const syncConnection = await this.#pool.getPeers( + Set([providerInfo.providerNode]), + this.server, + ()=>{} + ) + const providerConn = syncConnection.get(providerInfo.providerNode) + + if (providerConn === undefined){ + throw new Error("The latest model provider is not connected") + } + + const latestModel = await this.receiveModel(providerConn) + this.modelWeightAccess?.setModelWeight(latestModel) + + this.emit("modelSynced", this.modelWeightAccess?.getModelWeight()) + this.#modelSyncNeeded = false + } + // Notify the server we want to join the next round so that the server // waits for us to be ready before sending the list of peers for the round this.server.send({ type: type.JoinRound }) // Store the promise for the current round's aggregation result. // We will await for it to resolve at the end of the round when exchanging weight updates. this.aggregationResult = this.aggregator.getPromiseForAggregation() + + // Do not proceed to local training when minNbOfParticipants condition is not satisfied + await this.waitForParticipantsIfNeeded() + this.saveAndEmit("local training") return Promise.resolve() } @@ -149,11 +221,55 @@ export class DecentralizedClient extends Client<"decentralized"> { // Once enough new participants join we can display the previous status again this.saveAndEmit("connecting to peers") // First we check if we are waiting for more participants before sending our weight update - await this.waitForParticipantsIfNeeded() - // Create peer-to-peer connections with all peers for the round - await this.establishPeerConnections() + + while(true){ + // Wait until enough participants are available before continuing the round + // Checks minNbOfParticipants requirement for + // when participants disconnect when connection error happens continuously + await this.waitForParticipantsIfNeeded() + + // Create peer-to-peer connections with all peers for the round + await this.establishPeerConnections() + + // Wait for connection related messages from the server before exchanging weight updates + // (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange + // (2) If it receives a RetryPeerConnections message, it retries peer connection establishment + // (3) After multiple retires, if the connection is still unsuccessful, the server starts excluding nodes from the round + // and sends a ConnectionFail message to those nodes + // (4) Upon receiving ConnectionFail, the client disconnects from the server + const msg = await Promise.race([ + waitMessage(this.server, type.StartWeightSharing), + waitMessage(this.server, type.RetryPeerConnections), + waitMessage(this.server, type.ConnectionFail), + ]) + + if (msg.type === type.StartWeightSharing){ + // Generate a promise that resolves when round training finishes + if (this.#roundFinishedPromise === undefined){ + this.#roundFinishedPromise = new Promise((resolve) => { + this.#resolveRoundFinished = resolve + }) + } + break + } else if (msg.type === type.RetryPeerConnections){ + debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`) + // clear the communication round peer pool + await this.#pool?.shutdown() + this.#pool = new PeerPool(this.ownId) + // clear the connections + this.#connections = Map() + this.setAggregatorNodes(Set(this.ownId)) + continue + } else if (msg.type === type.ConnectionFail){ + debug(`[${shortenId(this.ownId)}] disconnect from the server`) + await this.disconnect() + throw new Error("Client disconnected after connection failure") + } + } // Exchange weight updates with peers and return aggregated weights - return await this.exchangeWeightUpdates(weights) + const aggregatedWeight = await this.exchangeWeightUpdates(weights) + + return aggregatedWeight } /** @@ -178,8 +294,9 @@ export class DecentralizedClient extends Client<"decentralized"> { try { debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); const receivedMessage = await waitMessage(this.server, type.PeersForRound) - + const peers = Set(receivedMessage.peers) + debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); if (this.ownId !== undefined && peers.has(this.ownId)) { throw new Error('received peer list contains our own id') @@ -198,7 +315,9 @@ export class DecentralizedClient extends Client<"decentralized"> { (conn) => this.receivePayloads(conn) ) - debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); + // Signal server that all connections with other peers in the round are established + this.server.send({ type: type.ConnectionsReady }); + debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS()); this.#connections = connections } catch (e) { debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e); @@ -303,4 +422,48 @@ export class DecentralizedClient extends Client<"decentralized"> { } return await this.aggregationResult } + + /** + * Receive model from the model provider. + */ + private async receiveModel(providerConn: PeerConnection): Promise{ + const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model") + + const decoded = serialization.weights.decode(message.model) + return decoded + } + + /** + * Send the latest available model to a newly joining client. + * If the current training round is in progress, wait until the round finishes + * and receive the latest aggregated model. + */ + private async sendModel(newcomerConn: PeerConnection, roundFinishedPromise: Promise | undefined): Promise { + // wait until the round finishes to get the latest model + if (roundFinishedPromise !== undefined){ + await roundFinishedPromise + } + + const model = this.modelWeightAccess?.getModelWeight() + + if (model === undefined){ + debug("Failed to get the latest model from model provider client") + return + } + const encoded = await serialization.weights.encode(model) + + const message: messages.SharedModel = { + type: type.SharedModel, + model: encoded + } + newcomerConn.send(message) + } + + // Resolve the round finished promise and reset related state + override finishRound(): void{ + // Mark round as finished so that model synchronization can proceed + this.#resolveRoundFinished?.() + this.#roundFinishedPromise = undefined + this.#resolveRoundFinished = undefined + } } diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index 626062ad4..1c56f45f8 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -12,6 +12,7 @@ export interface NewDecentralizedNodeInfo { id: NodeID waitForMoreParticipants: boolean nbOfParticipants: number + joinedMidTraining: boolean } // WebRTC signal to forward to other node @@ -38,6 +39,54 @@ export interface PeersForRound { aggregationRound: number } +// peer sends to server to signal all the connections to other peers +// are established +export interface ConnectionsReady { + type: type.ConnectionsReady +} + +// Server signals each peer to start weight update sharing +export interface StartWeightSharing { + type: type.StartWeightSharing +} + +// Server signals peers to reestablish peer connections +export interface RetryPeerConnections { + type: type.RetryPeerConnections + aggregationRound: number +} + +// Server signals a node that the connection with other peers failed +export interface ConnectionFail { + type: type.ConnectionFail +} + +// Nodes joining in the middle of the training send to server +// to request the latest model before starting local training +export interface ModelSyncRequest { + type: type.ModelSyncRequest +} + +// Server signals a node that shares the lastest model with node +// who joined in the middle of the training +export interface SignalNewPeer { + type: type.SignalNewPeer + newNode: NodeID +} + +// Server signals new node joining in the middle of the training +// about the model provider node +export interface SignalModelProvider { + type: type.SignalModelProvider + providerNode: NodeID +} + +// Sent by client to another client to share the latest model +export interface SharedModel { + type: type.SharedModel + model: serialization.Encoded +} + /// Phase 1 communication (between peers) export interface Payload { @@ -55,15 +104,24 @@ export type MessageFromServer = SignalForPeer | PeersForRound | WaitingForMoreParticipants | - EnoughParticipants + EnoughParticipants | + StartWeightSharing | + RetryPeerConnections | + ConnectionFail | + SignalModelProvider | + SignalNewPeer export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | - JoinRound + JoinRound | + ConnectionsReady | + ModelSyncRequest -export type PeerMessage = Payload +export type PeerMessage = + Payload | + SharedModel export function isMessageFromServer (o: unknown): o is MessageFromServer { if (!hasMessageType(o)) return false @@ -75,11 +133,17 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { typeof o.waitForMoreParticipants === 'boolean' case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? + 'signal' in o case type.PeersForRound: return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) + case type.SignalNewPeer: + return 'newNode' in o && isNodeID(o.newNode) case type.WaitingForMoreParticipants: case type.EnoughParticipants: + case type.StartWeightSharing: + case type.RetryPeerConnections: + case type.ConnectionFail: + case type.SignalModelProvider: return true } @@ -94,9 +158,11 @@ export function isMessageToServer (o: unknown): o is MessageToServer { return true case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? + 'signal' in o case type.JoinRound: case type.PeerIsReady: + case type.ConnectionsReady: + case type.ModelSyncRequest: return true } @@ -112,6 +178,10 @@ export function isPeerMessage (o: unknown): o is PeerMessage { 'peer' in o && isNodeID(o.peer) && 'payload' in o && serialization.isEncoded(o.payload) ) + case type.SharedModel: + return ( + 'model' in o && serialization.isEncoded(o.model) + ) } return false diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index f5b5f9bb4..eb6682694 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -24,8 +24,24 @@ export enum type { // Message forwarded by the server from a client to another client // to establish a peer-to-peer (WebRTC) connection SignalForPeer, + // Message sent by nodes to server to signal all connections are established + ConnectionsReady, + // Sent by the server to signal nodes proceed to weight update sharing + StartWeightSharing, + // Sent by the server to signal nodes reestablish connections + RetryPeerConnections, + // Sent by the server to signal that the node's connection was not successful + ConnectionFail, // The weight update Payload, + // Sent by nodes to the server to request the latest model + ModelSyncRequest, + // Sent by the server to nodes to share the provider node info + SignalModelProvider, + // Sent by the server to nodes who was selected as a model provider node + SignalNewPeer, + // Sent by node to node to share the latest model weights + SharedModel, /* Federated */ // The server answers the ClientConnected message with the necessary information diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts index 3a3537440..de00ef706 100644 --- a/discojs/src/default_tasks/cifar10.ts +++ b/discojs/src/default_tasks/cifar10.ts @@ -47,7 +47,8 @@ export const cifar10: TaskProvider<"image", "decentralized"> = { }, minNbOfParticipants: 3, maxShareValue: 100, - tensorBackend: 'tfjs' + tensorBackend: 'tfjs', + maxConnectionRetry: 3, } }); }, diff --git a/discojs/src/default_tasks/mnist.ts b/discojs/src/default_tasks/mnist.ts index 71b148507..0be9cf1ae 100644 --- a/discojs/src/default_tasks/mnist.ts +++ b/discojs/src/default_tasks/mnist.ts @@ -43,7 +43,8 @@ export const mnist: TaskProvider<"image", "decentralized"> = { }, minNbOfParticipants: 3, maxShareValue: 100, - tensorBackend: 'tfjs' + tensorBackend: 'tfjs', + maxConnectionRetry: 3, } }); }, diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index fa01cdf2e..55ef123ed 100644 --- a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -100,6 +100,9 @@ export namespace TrainingInformation { .object({ scheme: z.literal("decentralized"), aggregationStrategy: z.literal(["byzantine", "mean", "secure"]), + + // Maximum number of retries for connection failures + maxConnectionRetry: z.number().nonnegative().int().default(3), }) .and(nonLocalNetworkSchema), federated: z diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7dd019b2d..50c173680 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -16,6 +16,7 @@ import type { Network, Task, } from "../index.js"; +import { WeightsContainer } from "../index.js"; import type { Aggregator } from "../aggregator/index.js"; import { getAggregator } from "../aggregator/index.js"; import { enumerate, split } from "../utils/async_iterator.js"; @@ -70,6 +71,15 @@ function buildSummaryLog(roundNum: number, epochNum: number, roundLogs: RoundLog } } +/** + * Interface providing an access to trainer's model weights. + * Used for model synchronization to retrieve and set the latest model. + */ +export interface ModelWeightAccess{ + getModelWeight(): WeightsContainer; + setModelWeight(weight: WeightsContainer): void; +} + /** * Top-level class handling distributed training from a client's perspective. It is meant to be * a convenient object providing a reduced yet complete API that wraps model training and @@ -77,7 +87,8 @@ function buildSummaryLog(roundNum: number, epochNum: number, roundLogs: RoundLog */ export class Disco extends EventEmitter<{ status: RoundStatus; - participants: number + participants: number; + modelSynced: WeightsContainer | undefined; }> { public readonly trainer: Trainer; readonly #client: clients.Client; @@ -127,9 +138,19 @@ export class Disco extends EventEmitter<{ this.#client = client; this.#task = task; this.trainer = new Trainer(task, client); + // Set ModelWeightAccess of the client + this.#client.setModelWeightAccess({ + getModelWeight: () => { + return new WeightsContainer(this.trainer.model.weights.weights.map(t => t.clone())); + }, + setModelWeight: (weights) => { + this.trainer.model.weights = weights; + } + }); // Simply propagate the training status events emitted by the client this.#client.on("status", (status) => this.emit("status", status)); this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants)); + this.#client.on("modelSynced", (latestWeights) => this.emit("modelSynced", latestWeights)); } /** Train on dataset, yielding logs of every round. */ diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index 68e716bcc..87dada9db 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -141,6 +141,7 @@ export class Trainer { // Update the local weights this.model.weights = networkWeights; + this.#client.finishRound(); } } diff --git a/docs/DECENTRALIZED.md b/docs/DECENTRALIZED.md new file mode 100644 index 000000000..fa3c74b87 --- /dev/null +++ b/docs/DECENTRALIZED.md @@ -0,0 +1,55 @@ +# Peer Connection in Decentralized Learning +This document describes how peer connections for decentralized learning are established. +Relevant code can be found in [decentralized controller](../server/src/controllers/decentralized_controller.ts) and [decentralized client](../discojs/src/client/decentralized/decentralized_client.ts). + +Peer connections for decentralized learning are coordinated by the server. However, model weight updates are shared only between peers. + +## Peer Connection and Round Participation + +Clients first connect to the server. The server then coordinates which clients should participate in each decentralized training round. + +At the beginning of each round, clients send `JoinRound` messages to the server. After that, they send `PeerIsReady` messages to notify the server that they are ready to establish peer connections and exchange model updates. + +For each round, the server assumes that all currently connected clients are participating, except for clients that are still syncing their model after joining in the middle of training. Once the server has received `PeerIsReady` messages from all expected participants for the round, it sends a `PeersForRound` message to each participant. This message tells clients which peers they should connect to for the current round. + +After receiving `PeersForRound`, participants establish peer connections and proceed with decentralized learning. + +### Connection Ready Check and Signaling Weight Sharing + +Before clients start sharing model weights, the server checks that all participants have successfully completed peer connection setup. This prevents faster clients from starting weight sharing while slower clients are still establishing connections. +The process is as follows: +1. After completing peer connections, each client sends a `ConnectionsReady` message to the server. +2. The server counts the number of received `ConnectionsReady` messages. +3. Once the number of `ConnectionsReady` messages matches the number of expected participants for the round, the server sends a `StartWeightSharing` message to all round participants. +4. Clients only start sharing model updates after receiving `StartWeightSharing`. + +### Connection Retries and Failed Client Disconnection + +Peer connections may fail, so decentralized training includes a retry mechanism. The maximum number of retries is controlled by `maxConnectionRetry`, which is specified in the task training information. + +The retry mechanism is triggered when the server times out while waiting for `ConnectionsReady` messages. +The process is as follows: +1. If the number of retries is still below `maxConnectionRetry`, the server sends a `RetryPeerConnection` message to all peers in the current round. +2. When clients receive `RetryPeerConnection`, they clean up their peer pool and aggregator nodes, then rerun the peer connection phase. +3. If the connection setup still fails after more than `maxConnectionRetry` attempts, the server removes the failed peers from the round peer list. +4. The server sends a `ConnectionFail` message to the failed peers. +5. When a client receives `ConnectionFail`, it disconnects from the server. +6. The remaining clients receive `RetryPeerConnection` and retry the peer connection phase without the failed peers. + +This allows the remaining participants to continue training even when one or more peers fail to establish connections. + +### Model Syncing for Participants Joining in the Middle of Training + +Participants that join in the middle of training need to receive the latest model before they can participate in future rounds. In decentralized learning, peers do not send model weights to the server, so the newcomer must request the latest model from an existing peer. + +The model syncing process is as follows: +1. When a new participant joins in the middle of training, the server marks the participant as having joined mid-training in `NewDecentralizedNodeInfo`. +2. After receiving `NewDecentralizedNodeInfo`, the new client sets a local flag indicating that it needs model syncing. +3. When training begins, if this flag is set, the new client sends a `ModelSyncRequest` message to the server. +4. After receiving `ModelSyncRequest`, the server sends messages as step 5 and 6, using selected model provider information from previous training round. +5. The server sends `SignalModelProvider` to the new participant with information about the provider peer. +6. The server sends `SignalNewPeer` to the provider peer with information about the newly joined peer. +7. Using this signaling information, the new participant and provider peer establish a peer connection. +8. The provider waits until the current aggregation round has finished, then sends the latest model to the new participant using a `SharedModel` message. +9. The new participant receives the model and updates its local model weights. +10. After syncing, the new participant can join subsequent decentralized training rounds. \ No newline at end of file diff --git a/docs/FEDERATED.md b/docs/FEDERATED.md new file mode 100644 index 000000000..347d19d1e --- /dev/null +++ b/docs/FEDERATED.md @@ -0,0 +1,19 @@ +# Connections and Aggregations in Federated Learning +This documentation describes how connections between the server and clients are established, how model updates are aggregated and how updated weights are distributed. + +## Connecting to the Server +Clients participating in federated learning connect directly to the server. The server acts as the central coordination and aggregation point. Therefore, the clients only need to establish a connection with the server. + +When a client connects, the server assigns it a client ID and sends it the latest available global model weights and training information. The client initializes its local model with these weights and can begin training on its local model. + +## Aggregating Model Updates +After finishing local training for a round, each client sends its model update to the server using a `SendPayload` message. This message contains the client's current round number so that the server can synchronize model weights aggregation. + +The server checks the weight update contribution and adds it to the aggregator. When the required number of contributions has been received, the aggregator combines the client updates according to the aggregation mode and produces a new global model update. + +The server then sends the aggregated result to each participants using a `ReceiveServerPayload` message. Each client updates its local model to the received weights and proceeds to the next training round. + +After every successful aggregation, the server also stores the resulting weights as the latest global model weights. + +## Clients Joining During Training +When a new client joins an ongoing training round, the server sends it the latest available global model weights. The new client can then begin local training from the latest globally aggregated model. \ No newline at end of file diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 37ea428f4..ad984ab13 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -2,7 +2,7 @@ import createDebug from "debug"; import { v4 as randomUUID } from 'uuid' import * as msgpack from "@msgpack/msgpack"; import type WebSocket from 'ws' -import { Map } from 'immutable' +import { Map, Set } from 'immutable' import { client, DataType } from "@epfml/discojs"; @@ -21,7 +21,19 @@ export class DecentralizedController< // the node has already sent a PeerIsReady message) // We wait for all peers to be ready to exchange weight updates #roundPeers = Map() + #connectFinishedNodes = Map() #aggregationRound = 0 + #timeout?: NodeJS.Timeout + + // number of connection retries for the training round + #connectionRetry = 0 + + // Client selected to provide the latest model to peers + // joining in the middle of training + #providerNode?: client.NodeID + + // Set of nodes that are syncing node + #syncingNodes = Set() handle (ws: WebSocket): void { const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants @@ -45,13 +57,22 @@ export class DecentralizedController< case MessageTypes.ClientConnected: { debug(`peer [%s] joined ${this.task.id}`, shortId) this.connections = this.connections.set(peerId, ws) + + // If the new peer joins in the middle of training, + // it needs to get the latest model from an existing peer + let joinedMidTraining = false + if (this.#aggregationRound > 0){ + joinedMidTraining = true + this.#syncingNodes = this.#syncingNodes.add(peerId) + } // Answer with client id in an NewNodeInfo message const msg: messages.NewDecentralizedNodeInfo = { type: MessageTypes.NewDecentralizedNodeInfo, id: peerId, nbOfParticipants: this.connections.size, - waitForMoreParticipants: this.connections.size < minNbOfParticipants + waitForMoreParticipants: this.connections.size < minNbOfParticipants, + joinedMidTraining: joinedMidTraining, } ws.send(msgpack.encode(msg), { binary: true }) // Send an update to participants if we can start/resume training @@ -61,6 +82,7 @@ export class DecentralizedController< // Send by peers at the beginning of each training round to notify // the server that they want to join the round case MessageTypes.JoinRound: { + this.#syncingNodes = this.#syncingNodes.delete(peerId) this.#roundPeers = this.#roundPeers.set(peerId, false) break } @@ -84,6 +106,41 @@ export class DecentralizedController< this.connections.get(msg.peer)?.send(msgpack.encode(forward)) break } + case MessageTypes.ConnectionsReady: { + // Select the first client that finishes peer connections + // as the model provider for clients joined mid-training + const numconnFinishedNodes = this.#connectFinishedNodes.reduce((acc, val) => acc + (val ? 1 : 0), 0) + if (!numconnFinishedNodes){ + this.#providerNode = peerId + } + + this.#connectFinishedNodes = this.#connectFinishedNodes.set(peerId, true) + this.signalWeightSharing() + break + } + case MessageTypes.ModelSyncRequest: { + // Upon receiving a model sync request, send relevant client information + // to both the model provider and the newly joined client + if (!this.#providerNode) { + debug("There is no provider node to share the latest model") + break + } + + // Signal the newly joined client with the provider client's information + const providerInfo: messages.SignalModelProvider = { + type: MessageTypes.SignalModelProvider, + providerNode: this.#providerNode + } + this.connections.get(peerId)?.send(msgpack.encode(providerInfo)) + + // Signal the provider client with newly joined client's information + const newNodeInfo: messages.SignalNewPeer = { + type: MessageTypes.SignalNewPeer, + newNode: peerId + } + this.connections.get(this.#providerNode)?.send(msgpack.encode(newNodeInfo)) + break + } default: { const _: never = msg throw new Error('should never happen') @@ -98,13 +155,27 @@ export class DecentralizedController< // Remove the participant when the websocket is closed this.connections = this.connections.delete(peerId) this.#roundPeers = this.#roundPeers.delete(peerId) + this.#connectFinishedNodes = this.#connectFinishedNodes.delete(peerId) + + // Reset the training session when all participants leave + if (this.connections.size === 0){ + debug("All participants left. Resetting decentralized training session") + this.reset() + return + } + debug("client [%s] left", shortId) + // If this participant was a latest model provider node, + // replace the provider node to another node + if (this.#providerNode === peerId) { + this.#providerNode = this.connections.keySeq().first() + } + // Check if we are already waiting for new participants to join if (this.waitingForMoreParticipants) return // If no, check if we are still above the minimum number of participant required if (this.connections.size >= minNbOfParticipants) { - // Check if remaining peers are all ready to exchange weight updates this.sendPeersForRoundIfNeeded() return } @@ -113,6 +184,24 @@ export class DecentralizedController< this.sendWaitForMoreParticipantsMsg() }) } + + reset(): void { + this.resetConnectionState() + + this.#roundPeers = Map() + this.#connectFinishedNodes = Map() + this.#aggregationRound = 0 + this.#connectionRetry = 0 + this.#providerNode = undefined + this.#syncingNodes = Set() + + // Reset the timeout + if (this.#timeout !== undefined){ + clearTimeout(this.#timeout) + this.#timeout = undefined + } + } + /** * Check if we have enough participants to start the training * and if all peers that joined the round are ready to exchange weight updates @@ -121,10 +210,16 @@ export class DecentralizedController< private sendPeersForRoundIfNeeded(): void { const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants const nbOfPeersReady = this.#roundPeers.filter(ready => ready).size + // participating peers are connected peers expect the ones that are in process of syncing + const participatingPeers = this.connections.keySeq().toSet().subtract(this.#syncingNodes) + // First check if there are enough participants to start the round // Then check if all peers that wanted to join this round are ready + // All peers that are connected to the server (except for newly joining peers waiting for the latest model) + // are expected to participate in the round + if (nbOfPeersReady < minNbOfParticipants - || nbOfPeersReady != this.#roundPeers.size) return + || nbOfPeersReady != participatingPeers.size) return // Once every peer that joined the round is ready, we can start the round this.#roundPeers.keySeq() .map((id) => { @@ -145,9 +240,165 @@ export class DecentralizedController< } return [conn, encoded] as [WebSocket, Buffer] }).forEach(([conn, encoded]) => { conn.send(encoded) }) + + // Initialize connectFinishedNodes with all peers set to false + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + // Change the peer states to not ready + this.#roundPeers = this.#roundPeers.map(() => false) + + // Start timeout to check peer connections are successful + this.startTimeout() + } + + /** + * Check if all the participants of the round finished connecting + * with other peers in the round + * If so, send StartWeightSharing message to signal peers to proceed + */ + private signalWeightSharing(): void { + // Return if not all participants are ready + if (!this.#connectFinishedNodes.every((ready) => ready)) + return + + // Stop the timeout + this.clearTimeout() + + // Send round participants StartWeightSharing messages + this.#roundPeers.keySeq() + .map((id) => { + const startSignal = { + type: MessageTypes.StartWeightSharing, + } + debug("Signaling weight sharing to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(startSignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + // empty the list of peers for the next round this.#roundPeers = Map() + this.#connectFinishedNodes = Map() this.#aggregationRound++ } + + /** + * Set a timeout to check peer connections establishment + */ + private startTimeout(maxTime: number = 60_000): void { + this.#timeout = setTimeout(() => { + this.handleTimeout() + }, maxTime) + } + + /** + * Clear previously set timeout once all peer connections + * are established before the timeout + */ + private clearTimeout(): void { + if (this.#timeout !== undefined){ + clearTimeout(this.#timeout) + this.#timeout = undefined + } + } + + /** + * Called when a timeout occurs during peer connection + * Signals peers to discard existing connections and + * reestablish connections with the current set of peers + */ + private handleTimeout(): void { + this.clearTimeout() + debug(`Connection setup timeout for round ${this.#aggregationRound}, Retrying with same peers`) + // Increment the connection retry count + this.#connectionRetry += 1; + + // If the number of retries exceeds the threshold, exclude the failed peers from the round + // and retry peer connection only with the remaining peers + if (this.#connectionRetry >= this.task.trainingInformation.maxConnectionRetry){ + // Exclude the failed peers + this.#connectFinishedNodes.forEach((connected, nodeId) => { + if (!connected){ + // If the node failed connection, exclude from #roundPeers + this.#roundPeers = this.#roundPeers.delete(nodeId) + this.#connectFinishedNodes = this.#connectFinishedNodes.delete(nodeId) + // Signal the node that connection is failed for that node + const conn = this.connections.get(nodeId) + if (conn === undefined) { + throw new Error(`peer ${nodeId} marked as ready but not connection to it`) + } + const failSignal : messages.ConnectionFail = { + type: MessageTypes.ConnectionFail + } + const encoded = msgpack.encode(failSignal) + conn.send(encoded) + } + }) + + // Restart the round with remaining clients + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // Reset the ready and connection status of roundPeers + this.#roundPeers = this.#roundPeers.map(() => false) + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + + // Reset the connectionRetry since we excluded the failed clients + this.#connectionRetry = 0 + // Restart the timeout after sending retry messages + this.startTimeout() + return + } + + // Retry peer connection with original roundPeers + // Reset the ready and connection status of roundPeers + this.#roundPeers = this.#roundPeers.map(() => false) + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // Restart the timeout after sending retry messages + this.startTimeout() + } } diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index 1e52fe539..1d75f18dc 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -151,9 +151,8 @@ export class FederatedController extends TrainingController< // Reset the training session when all participants left if (this.connections.size === 0) { - debug("All participants left. Resetting the training session") - this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') - this.#latestGlobalWeights = this.initialWeights + debug("All participants left. Resetting federated training session") + this.reset() } // Check if we dropped below the minimum number of participant required @@ -166,4 +165,18 @@ export class FederatedController extends TrainingController< this.sendWaitForMoreParticipantsMsg() }) } + + reset(): void { + this.resetConnectionState() + + this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') + this.#latestGlobalWeights = this.initialWeights + + // Since we replaced aggregator, we also need to register new aggregation listener + this.#aggregator.on("aggregation", async (weightUpdate) => { + this.#latestGlobalWeights = await serialization.weights.encode(weightUpdate) + }) + + this.#aggregator.dispose() + } } diff --git a/server/src/controllers/training_controller.ts b/server/src/controllers/training_controller.ts index bbee678ec..bff46247d 100644 --- a/server/src/controllers/training_controller.ts +++ b/server/src/controllers/training_controller.ts @@ -44,6 +44,16 @@ export abstract class TrainingController< abstract handle( ws: WebSocket ): void + + // Reset the controller state + abstract reset(): void + + // Reset the peer connection state + // Used when the training is finished + protected resetConnectionState(): void { + this.waitingForMoreParticipants = true + this.connections = Map() + } /** * If enough participants joined, notifies them that the training can start/resume diff --git a/server/tests/e2e/decentralized.spec.ts b/server/tests/e2e/decentralized.spec.ts index 0a0fb29af..6ecf7f337 100644 --- a/server/tests/e2e/decentralized.spec.ts +++ b/server/tests/e2e/decentralized.spec.ts @@ -1,5 +1,5 @@ import type * as http from "node:http"; -import type { DataType, RoundStatus, Task, TaskProvider } from "@epfml/discojs"; +import type { DataType, RoundStatus, Task, TaskProvider, EpochLogs } from "@epfml/discojs"; import { aggregator as aggregators, client as clients, @@ -27,6 +27,46 @@ async function expectWSToBeClose( expect(l).to.be.closeTo(r, 1e-4); } +// function from federated.spec.ts +export async function arrayFromAsync(iter: AsyncIterable): Promise { + const ret: T[] = []; + for await (const e of iter) { + // TODO trick to allow other Promises to run + // else one client might progress alone without communicating with others + // will be fixed when client orchestrations in the server is correctly done + await new Promise((resolve) => setTimeout(resolve, 10)); + + ret.push(e); + } + return ret; +} + +// function to check if weights across all participants are close to each other +async function expectAllWSToBeClose( + weights: WeightsContainer[] +): Promise { + const reference = weights[0] + + await Promise.all( + weights.map(async (current) => { + await expectWSToBeClose(reference, current) + }) + ) +} + +const expectWeightsToEqual = ( + a: WeightsContainer, + b: WeightsContainer, +) => { + expect(a.weights.length).to.equal(b.weights.length); + + a.weights.forEach((w, i) => { + expect(Array.from(w.dataSync())).to.deep.equal( + Array.from(b.weights[i].dataSync()), + ); + }); +}; + describe("end-to-end decentralized", { timeout: 50_000 }, () => { let handle: http.Server | undefined; async function startServer( @@ -134,196 +174,444 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { await reachConsensus(url, "secure", 3); }); - it("peers emit expected events", { timeout: 100_000 }, async () => { - const baseTask = await defaultTasks.lusCovid.getTask(); - const task: Task<"image", "decentralized"> = { - ...baseTask, - trainingInformation: { - ...baseTask.trainingInformation, - scheme: "decentralized", - aggregationStrategy: "mean", - roundDuration: 1, - minNbOfParticipants: 2, - }, - }; - const url = await startServer({ - ...defaultTasks.lusCovid, - getTask: () => Promise.resolve(task), - }); - const dataset = await datasets.loadLusCOVID(); + /** + * Unit tests with 10 participants + */ + // Mean aggregator + it("ten cifar10 users reach consensus with mean aggregation", { timeout: 300_000 }, async () => { + const baseTask = await defaultTasks.cifar10.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "mean", + epochs: 3, + roundDuration: 1, + minNbOfParticipants: 10, + }, + }; + + const url = await startServer({ + ...defaultTasks.cifar10, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadCifar10(); + + const discos = Array.from( + { length: 10 }, + () => new Disco(task, url, { preprocessOnce: true }), + ); + + try{ + const results = await Promise.all( + discos.map(async (disco) => { + const logs = List(await arrayFromAsync(disco.trainByRound(dataset))); + const lastEpoch = logs.last()?.epochs.last(); + if (lastEpoch === undefined) throw new Error("no epoch ran"); + + return [disco.trainer.model.weights, lastEpoch] as [WeightsContainer, EpochLogs]; + }) + ); + + await expectAllWSToBeClose(results.map(([weights])=>weights)); + }finally{ + await Promise.all(discos.map((disco) => disco.close())); + } + }); + + // Byzantine aggregator + it("ten cifar10 users reach consensus with byzantine aggregation", { timeout: 300_000 }, async () => { + const baseTask = await defaultTasks.cifar10.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "byzantine", + epochs: 3, + roundDuration: 1, + minNbOfParticipants: 10, + privacy: { + byzantineFaultTolerance: { + clippingRadius: 10, + maxIterations: 1, + beta: 0.9, + }, + }, + }, + }; + + const url = await startServer({ + ...defaultTasks.cifar10, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadCifar10(); + + const discos = Array.from( + { length: 10 }, + () => new Disco(task, url, { preprocessOnce: true }), + ); + + try{ + const results = await Promise.all( + discos.map(async (disco) => { + const logs = List(await arrayFromAsync(disco.trainByRound(dataset))); + const lastEpoch = logs.last()?.epochs.last(); + if (lastEpoch === undefined) throw new Error("no epoch ran"); + + return [disco.trainer.model.weights, lastEpoch] as [WeightsContainer, EpochLogs]; + }) + ); + + await expectAllWSToBeClose(results.map(([weights])=>weights)); + }finally{ + await Promise.all(discos.map((disco) => disco.close())); + } + }); + + // Secure aggregator + it("ten cifar10 users reach consensus with secure aggregation", { timeout: 500_000 }, async () => { + const baseTask = await defaultTasks.cifar10.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "secure", + epochs: 10, + roundDuration: 1, + minNbOfParticipants: 10, + maxShareValue: 100, + }, + }; + + const url = await startServer({ + ...defaultTasks.cifar10, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadCifar10(); + + const discos = Array.from( + { length: 10 }, + () => new Disco(task, url, { preprocessOnce: true }), + ); + + try{ + const results = await Promise.all( + discos.map(async (disco) => { + const logs = List(await arrayFromAsync(disco.trainByRound(dataset))); + const lastEpoch = logs.last()?.epochs.last(); + if (lastEpoch === undefined) throw new Error("no epoch ran"); + + return [disco.trainer.model.weights, lastEpoch] as [WeightsContainer, EpochLogs]; + }) + ); + + await expectAllWSToBeClose(results.map(([weights])=>weights)); + }finally{ + await Promise.all(discos.map((disco) => disco.close())); + } + }); + + // syncs model after participants drop below minNbOfParticipants and newcomers join with mean aggregator + it("peers emit expected events", { timeout: 150_000 }, async () => { + const baseTask = await defaultTasks.lusCovid.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "mean", + roundDuration: 1, + minNbOfParticipants: 2, + maxConnectionRetry: 3, + }, + }; + + const url = await startServer({ + ...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); /** - * Then at each round (each call to `disco.trainByRound`) the event cycle is: - * a) During onRoundBeingCommunication, - * 1. the peer notifies the server that they want to join the next round - * 2. finishes by updating the status to "local training" - * (without waiting for a server answer) - * b) local training (the status remains "local training") - * c) During onRoundEndCommunication - * 1. the peer notifies the server that they are ready to share weights - * set status to "connecting to peers" - * 2. wait for the server to answer with the current round's peers list - * this is where the nb of participants is updated - * 3. establish peer-to-peer connections - * 4. set status to "updating model" and exchange weight updates - * - * Given this, it is important to note that calling disco.trainByRound().next() - * for the first time will perform a) and then b) where it stops and yields the round logs. - * Thus, c) isn't called and the weight sharing is not performed during this call to next(). - * Calling next() again will then run c), as well as a) and b) again. - * - * In this test the timeline is: - * - User 1 joins the task by themselves + * Test timeline looks like this: + * - User 1 joins the task * - User 2 joins - * - User 1 leaves - * - User 3 joins + * - User 1 leaves (Since minNbOfParticipants condition is not satisfied, the training stops) + * - User 3 joins (User 3 gets the latest model from User 2 and start local training from that model) * - User 2 & 3 leave */ - /* USER 1 JOINS */ - const discoUser1 = new Disco(task, url, { preprocessOnce: true }); + const discoUser2 = new Disco(task, url, { preprocessOnce: true }); + + // Register listeners for user1 and user2 events const statusUser1 = new Queue(); const nbParticipantsUser1 = new Queue(); - discoUser1.on("status", status => { statusUser1.put(status) }) - discoUser1.on("participants", (participants) => { nbParticipantsUser1.put(participants) }) - const generatorUser1 = discoUser1.trainByRound(dataset) - - // Have User 1 join the task and train locally for one round - const logUser1Round1 = await generatorUser1.next() - expect(logUser1Round1.done).to.be.false - // User 1 did a) and b) so their status should be Training - expect(await statusUser1.next()).equal("local training") - expect(await nbParticipantsUser1.next()).equal(1) - - if (logUser1Round1.done) - throw new Error("User 1 finished training at the 1st round") - // participant list not updated yet (updated at step c)) - expect((logUser1Round1.value).participants).equal(1) - - // Calling next() a 2nd time makes User 1 go to c) where the peer should - // stay stuck awaiting until another participant joins - const logUser1Round2Promise = generatorUser1.next() - expect(await statusUser1.next()).equal("connecting to peers") // tries to connect to peers - expect(await statusUser1.next()).equal("not enough participants") // but has to wait for more participants + const statusUser2 = new Queue(); + const nbParticipantsUser2 = new Queue(); + discoUser1.on("status", (status) => statusUser1.put(status)); + discoUser1.on("participants", (participants) => nbParticipantsUser1.put(participants)); + discoUser2.on("status", (status) => statusUser2.put(status)); + discoUser2.on("participants", (participants) => nbParticipantsUser2.put(participants)); + + let user2Closed = false; + + const generatorUser1 = discoUser1.trainByRound(dataset); + const generatorUser2 = discoUser2.trainByRound(dataset); + + /* ROUND 1 */ + /* USER 1 JOINS */ + const round1User1Promise = generatorUser1.next(); + expect(await statusUser1.next()).equal("not enough participants"); + // We expect only one participant + expect(await nbParticipantsUser1.next()).equal(1); /* USER 2 JOINS */ + /* minNbOfParticipants condition satisfied, local training starts */ + const round1User2Promise = generatorUser2.next(); + await Promise.all([round1User1Promise, round1User2Promise]); + + expect(await statusUser2.next()).equal("local training"); + expect(await statusUser1.next()).equal("local training"); + expect(await nbParticipantsUser1.next()).equal(2); + expect(await nbParticipantsUser2.next()).equal(2); + + /* ROUND 2 - first weight exchange */ + await Promise.all([ + generatorUser1.next(), + generatorUser2.next(), + ]); + + // Both users did connecting -> updating model -> local training + expect(await statusUser1.next()).equal("connecting to peers"); + expect(await statusUser1.next()).equal("updating model"); + expect(await statusUser1.next()).equal("local training"); + expect(await statusUser2.next()).equal("connecting to peers"); + expect(await statusUser2.next()).equal("updating model"); + expect(await statusUser2.next()).equal("local training"); + expect(await nbParticipantsUser1.next()).equal(2); + expect(await nbParticipantsUser2.next()).equal(2); + + // Weights should have converged after exchanging updates + await expectWSToBeClose( + discoUser1.trainer.model.weights, + discoUser2.trainer.model.weights, + ); - const discoUser2 = new Disco(task, url, { preprocessOnce: true }); - const statusUser2 = new Queue(); - const nbParticipantsUser2 = new Queue(); - discoUser2.on("status", status => { statusUser2.put(status) }) - discoUser2.on("participants", (participants) => { nbParticipantsUser2.put(participants) }) - const generatorUser2 = discoUser2.trainByRound(dataset) - - // Have User 2 join the task and train for one round - const logUser2Round1 = await generatorUser2.next() - expect(logUser2Round1.done).to.be.false - if (logUser2Round1.done) - throw new Error("User 2 finished training at the 1st round") - // round payload should contain the number of participants - expect((logUser2Round1.value).participants).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) - // Receive the EnoughParticipants message with the participants - expect(await nbParticipantsUser1.next()).equal(2) - // User 2 did a) and b) - expect(await statusUser2.next()).equal("local training") - // User 1 is still in c) now waiting for user 2 to be ready to exchange weight updates - expect(await statusUser1.next()).equal("connecting to peers") - - /* ROUND 2 */ - - // The server should answer with the round's peers list. - // Peers then exchange updates and then start training locally with the new weights - const logUser2Round2 = await generatorUser2.next() - const logUser1Round2 = await logUser1Round2Promise // the promise can resolve now - expect(logUser1Round2.done).to.be.false - expect(logUser2Round2.done).to.be.false - if (logUser1Round2.done || logUser2Round2.done) - throw new Error("User 1 or 2 finished training at the 2nd round") - // nb of participants should now be updated - expect((logUser1Round2.value).participants).equal(2) - expect((logUser2Round2.value).participants).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) - expect(await nbParticipantsUser1.next()).equal(2) - // User 1 and 2 did c), a) and b) - expect(await statusUser1.next()).equal("updating model") // second to last - expect(await statusUser1.next()).equal("local training") - - expect(await statusUser2.next()).equal("connecting to peers") // back to connecting when user 1 joins - expect(await statusUser2.next()).equal("updating model") - expect(await statusUser2.next()).equal("local training") - - /* USER 1 LEAVES */ - - await discoUser1.close() - // Disconnect updates the number of participants - expect(await nbParticipantsUser1.next()).equal(1) - // User 2 receives the WaitingForMoreParticipants message - expect(await nbParticipantsUser2.next()).equal(1) - // server notifies user 2 to wait - expect(await statusUser2.next()).equal("not enough participants") - // Make user 2 go to c) - const logUser2Round3Promise = generatorUser2.next() - // await new Promise((res, _) => setTimeout(res, statusUpdateTime)) // Wait some time for the status to update - // starts c) and waits for user 3 to join - expect(await statusUser2.next()).equal("connecting to peers") - expect(await statusUser2.next()).equal("not enough participants") + /* USER 2 LEAVES */ + + // ROUND3 starts for User 1 before closing User 2, so User 1 + // enters onRoundEndCommunication and emits "connecting to peers" + const user1WaitingPromise = generatorUser1.next(); + expect(await statusUser1.next()).equal("connecting to peers"); + + await discoUser2.close(); + user2Closed = true; + + // Check if User 1 got a signal that there is not enough participants + expect(await nbParticipantsUser1.next()).equal(1); + expect(await statusUser1.next()).equal("not enough participants"); + + // Snapshot User 1's weights, which will be shared to User 3 for model syncing + const latestWeights = new WeightsContainer( + discoUser1.trainer.model.weights.weights.map((w) => w.clone()), + ); /* USER 3 JOINS */ - // Create User 3 + // Create User 3 and register event listeners const discoUser3 = new Disco(task, url, { preprocessOnce: true }); const statusUser3 = new Queue(); const nbParticipantsUser3 = new Queue(); - discoUser3.on("status", status => { statusUser3.put(status) }) - discoUser3.on("participants", (participants) => { nbParticipantsUser3.put(participants) }) - const generatorUser3 = discoUser3.trainByRound(dataset) - - // User 3 joins mid-training and trains one local round - const logUser3Round1 = await generatorUser3.next() - expect(logUser3Round1.done).to.be.false - if (logUser3Round1.done) - throw new Error("User 3 finished training at the 1st round") - expect((logUser3Round1.value).participants).equal(2) - expect(await nbParticipantsUser3.next()).equal(2) - // User 2 receives the EnoughParticipants message - // User 2 is still in c) waiting for user 3 to share their local update - expect(await nbParticipantsUser2.next()).equal(2) - - // User 3 did a) and b) - expect(await statusUser3.next()).equal("local training") - // User 2 is still in c) waiting for user 3 to be ready to exchange waits - expect(await statusUser2.next()).equal("connecting to peers") - - /* ROUND 3 */ - - // User 3 notifies the server that they are ready to exchange waits - // then user 2 and 3 exchange weight updates - const logUser3Round3 = await generatorUser3.next() - const logUser2Round3 = await logUser2Round3Promise // the promise can resolve now - if (logUser3Round3.done || logUser2Round3.done) - throw new Error("User 1 or 2 finished training at the 3rd round") - - expect(logUser2Round3.value.participants).equal(2) - expect(logUser3Round3.value.participants).equal(2) - expect(await nbParticipantsUser3.next()).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) - - // both user 2 and 3 did c), a) and are now in b) - expect(await statusUser2.next()).equal("updating model") - expect(await statusUser2.next()).equal("local training") - - expect(await statusUser3.next()).equal("connecting to peers") - expect(await statusUser3.next()).equal("updating model") - expect(await statusUser3.next()).equal("local training") - - /* USER 2 AND 3 LEAVE */ - - await discoUser2.close() - expect(await statusUser3.next()).equal("not enough participants") - expect(await nbParticipantsUser3.next()).equal(1) - - await discoUser3.close() + discoUser3.on("status", (status) => statusUser3.put(status)); + discoUser3.on("participants", (participants) => nbParticipantsUser3.put(participants)); + + const waitForUser3ModelSynced = new Promise((resolve) => { + discoUser3.on("modelSynced", (weights) => { + if (weights !== undefined) resolve(weights); + }); + }); + + const generatorUser3 = discoUser3.trainByRound(dataset); + + /* ROUND 3 */ + /* User 3's first round */ + const user3Round1 = await generatorUser3.next(); + expect(user3Round1.done).to.be.false; + expect(user3Round1.value.participants).equal(2); + + // User 3's model should have been synced to User 1's weights before training + const user3SyncedWeights = await waitForUser3ModelSynced; + await expectWSToBeClose(user3SyncedWeights, latestWeights); + + // User 3 did onRoundBeginCommunication and local training + expect(await statusUser3.next()).equal("local training"); + expect(await nbParticipantsUser3.next()).equal(2); + // User 1 learns User 3 joined and is still in onRoundEndCommunication waiting for User 3 to be ready + expect(await nbParticipantsUser1.next()).equal(2); + expect(await statusUser1.next()).equal("connecting to peers"); + + /* ROUND 4 */ + /* first weight exchange between User 1 and User 3 */ + const [user1Round, user3Round2] = await Promise.all([ + user1WaitingPromise, + generatorUser3.next(), + ]); + expect(user1Round.done).to.be.false; + expect(user3Round2.done).to.be.false; + expect(user1Round.value.participants).equal(2); + expect(user3Round2.value.participants).equal(2); + + // Both users did onRoundEndCommunication, onRoundBeginCommunication, and local training + expect(await statusUser1.next()).equal("updating model"); + expect(await statusUser1.next()).equal("not enough participants"); + expect(await statusUser1.next()).equal("local training"); + expect(await nbParticipantsUser1.next()).equal(2); + + expect(await statusUser3.next()).equal("connecting to peers"); + expect(await statusUser3.next()).equal("updating model"); + expect(await statusUser3.next()).equal("local training"); + expect(await nbParticipantsUser3.next()).equal(2); + + + // Weights should have converged between User 1 and User 3 after the exchange + await expectWSToBeClose( + discoUser1.trainer.model.weights, + discoUser3.trainer.model.weights, + ); + + await discoUser3.close(); + await discoUser1.close().catch(() => {}); + if (!user2Closed) await discoUser2.close().catch(() => {}); + }); + + /** + * We test if the latest model syncing is working when new participant + * joins in the middle of the training (when the round > 0). + * + * The test workflow + * 1. Start User1 and User2 starts training + * 2. Let them complete at least one aggregation round + * 3. Start User3 when aggregationRound is larger than 0 + * 4. When User3 starts training, model synchronization should be triggered first + * 5. Compare User3's model weights with User1/User2's latest model weights + */ + it("Model Syncing when new participant joins in the middle of the training", { timeout: 200_000 }, async () => { + const baseTask = await defaultTasks.lusCovid.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "mean", + roundDuration: 1, + minNbOfParticipants: 2, + maxConnectionRetry: 3, + }, + }; + + const url = await startServer({ + ...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); + + const discoUser1 = new Disco(task, url, { preprocessOnce: true }); + const discoUser2 = new Disco(task, url, { preprocessOnce: true }); + const discoUser3 = new Disco(task, url, { preprocessOnce: true }); + + try { + const generatorUser1 = discoUser1.trainByRound(dataset); + const generatorUser2 = discoUser2.trainByRound(dataset); + + await Promise.all([ + generatorUser1.next(), + generatorUser2.next(), + ]); + + await Promise.all([ + generatorUser1.next(), + generatorUser2.next(), + ]); + + // Existing participants should already have the same aggregated model. + await expectWSToBeClose( + discoUser1.trainer.model.weights, + discoUser2.trainer.model.weights, + ); + + const waitForModelSynced = Promise.race([ + new Promise<{ + syncedWeights: WeightsContainer; + providerWeightsUser1: WeightsContainer; + providerWeightsUser2: WeightsContainer; + }>((resolve) => { + discoUser3.on("modelSynced", (weights) => { + if (weights !== undefined) { + resolve({ + syncedWeights: weights, + providerWeightsUser1: new WeightsContainer( + discoUser1.trainer.model.weights.weights.map((w) => w.clone()), + ), + providerWeightsUser2: new WeightsContainer( + discoUser2.trainer.model.weights.weights.map((w) => w.clone()), + ), + }); + } + }); + }), + new Promise((_, reject) => + setTimeout( + () => reject(new Error("Timed out waiting for modelSynced")), + 60_000, + ), + ), + ]); + + const generatorUser3 = discoUser3.trainByRound(dataset); + const user3RoundPromise = generatorUser3.next(); + + await new Promise((resolve) => setTimeout(resolve, 5_000)); + + // The newcomer may ask for synchronization while existing participants are + // already in the next local round. Progress peers until the provider sends the latest model. + for (let attempt = 0; attempt < 5; attempt++) { + const synced = await Promise.race([ + waitForModelSynced.then(() => true), + new Promise((resolve) => + setTimeout(() => resolve(false), 100), + ), + ]); + + if (synced) break; + + await Promise.all([ + generatorUser1.next(), + generatorUser2.next(), + ]); + } + + const { + syncedWeights, + providerWeightsUser1, + providerWeightsUser2, + } = await waitForModelSynced; + + try { + expectWeightsToEqual(syncedWeights, providerWeightsUser1); + } catch { + expectWeightsToEqual(syncedWeights, providerWeightsUser2); + } + + const user3Round = await user3RoundPromise; + expect(user3Round.done).to.be.false; + } finally { + await discoUser1.close(); + await discoUser2.close(); + await discoUser3.close(); + } }); }) diff --git a/webapp/src/components/task_creation_form/TaskCreationForm.vue b/webapp/src/components/task_creation_form/TaskCreationForm.vue index 5ff1d705a..fb875e7ab 100644 --- a/webapp/src/components/task_creation_form/TaskCreationForm.vue +++ b/webapp/src/components/task_creation_form/TaskCreationForm.vue @@ -348,6 +348,25 @@ + +
+ Maximum number of connection retries before disconnecting failed participants. +
+ + +
+ + { // manually interrupt the training cleanupDisco.value = async () => await disco.close() + // For the training completed message + let trainingCompleted = true + try { trainingGenerator.value = disco.train(dataset); @@ -245,6 +248,7 @@ async function startTraining(): Promise { epochsOfRoundLogs.value = List(); } } catch (e) { + trainingCompleted = false; if (e === stopper) { toaster.info("Training stopped"); return; @@ -262,8 +266,22 @@ async function startTraining(): Promise { toaster.error( "Training is not converging. Data potentially needs better preprocessing.", ); + } else if ( + e instanceof Error && + e.message.includes("Client disconnected after connection failure") + ){ + toaster.error( + "Client disconnected after multiple peer connection failure. Please rejoin the training." + ); + } else if ( + e instanceof Error && + e.message.includes("Timeout while waiting for the latest model") + ){ + toaster.error( + "Timeout while waiting for the model syncing. Please rejoin the training." + ); } else { - toaster.error("An error occurred during training"); + toaster.error("An error occurred during training.") } debug("while training: %o", e); } finally { @@ -271,7 +289,10 @@ async function startTraining(): Promise { await cleanupTrainingSession() } - toaster.success("Training successfully completed"); + if (trainingCompleted){ + // printed only when the training is compeleted successfully + toaster.success("Training successfully completed"); + } } async function cleanupTrainingSession() {