From 2900c652eaf838789ce1eef4d74cbd330ff5e5bb Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 15 Apr 2026 14:52:47 +0200 Subject: [PATCH 01/10] Add ConnectionsReady, StartWeightSharing messages --- .../decentralized/decentralized_client.ts | 11 +++-- discojs/src/client/decentralized/messages.ts | 19 ++++++++- discojs/src/client/messages.ts | 4 ++ .../controllers/decentralized_controller.ts | 42 ++++++++++++++++++- 4 files changed, 69 insertions(+), 7 deletions(-) diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 6f9da6e77..2851ce197 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -152,7 +152,9 @@ export class DecentralizedClient extends Client<"decentralized"> { await this.waitForParticipantsIfNeeded() // Create peer-to-peer connections with all peers for the round await this.establishPeerConnections() - // Exchange weight updates with peers and return aggregated weights + // Wait StartWeightSharing message from the server before exchanging weight updates + await waitMessage(this.server, type.StartWeightSharing) + // Exchange weight updates with peers and return aggregated weights // and then send out the contributions return await this.exchangeWeightUpdates(weights) } @@ -178,8 +180,9 @@ export class DecentralizedClient extends Client<"decentralized"> { try { debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); const receivedMessage = await waitMessage(this.server, type.PeersForRound) - + const peers = Set(receivedMessage.peers) + debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); if (this.ownId !== undefined && peers.has(this.ownId)) { throw new Error('received peer list contains our own id') @@ -198,7 +201,9 @@ export class DecentralizedClient extends Client<"decentralized"> { (conn) => this.receivePayloads(conn) ) - 
debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); + // Signal server that all connections with other peers in the round are established + this.server.send({ type: type.ConnectionsReady }); + debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS()); this.#connections = connections } catch (e) { debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e); diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index 626062ad4..ae6b2d891 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -38,6 +38,17 @@ export interface PeersForRound { aggregationRound: number } +// peer sends to server to signal all the connections to other peers +// are established +export interface ConnectionsReady { + type: type.ConnectionsReady +} + +// Server signal each peer to start weight update sharing +export interface StartWeightSharing { + type: type.StartWeightSharing; +} + /// Phase 1 communication (between peers) export interface Payload { @@ -55,13 +66,15 @@ export type MessageFromServer = SignalForPeer | PeersForRound | WaitingForMoreParticipants | - EnoughParticipants + EnoughParticipants | + StartWeightSharing export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | - JoinRound + JoinRound | + ConnectionsReady export type PeerMessage = Payload @@ -80,6 +93,7 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) case type.WaitingForMoreParticipants: case type.EnoughParticipants: + case type.StartWeightSharing: return true } @@ -97,6 +111,7 @@ export function isMessageToServer (o: unknown): o is MessageToServer { 'signal' in o // TODO check signal content? 
case type.JoinRound: case type.PeerIsReady: + case type.ConnectionsReady: return true } diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index f5b5f9bb4..b0e9dc4e7 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -24,6 +24,10 @@ export enum type { // Message forwarded by the server from a client to another client // to establish a peer-to-peer (WebRTC) connection SignalForPeer, + // Message sent by nodes to server to signal all connections are established + ConnectionsReady, + // Sent by the server to signal nodes proceed to weight update sharing + StartWeightSharing, // The weight update Payload, diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 37ea428f4..fa73ea768 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -21,6 +21,7 @@ export class DecentralizedController< // the node has already sent a PeerIsReady message) // We wait for all peers to be ready to exchange weight updates #roundPeers = Map() + #connectFinishedNodes = Map() #aggregationRound = 0 handle (ws: WebSocket): void { @@ -84,6 +85,11 @@ export class DecentralizedController< this.connections.get(msg.peer)?.send(msgpack.encode(forward)) break } + case MessageTypes.ConnectionsReady: { + this.#connectFinishedNodes = this.#connectFinishedNodes.set(peerId, true) + this.signalWeightSharing() + break + } default: { const _: never = msg throw new Error('should never happen') @@ -145,9 +151,41 @@ export class DecentralizedController< } return [conn, encoded] as [WebSocket, Buffer] }).forEach(([conn, encoded]) => { conn.send(encoded) }) + + // Initialize connectFinishedNodes with all peers set to false + this.#connectFinishedNodes = this.#roundPeers.map(() => false) as Map + this.#aggregationRound++ + } + + /** + * Check if all the participants of the round finished connecting + * with other peers in 
the round + * If so, send StartWeightSharing message to signal peers to proceed + */ + private signalWeightSharing(): void { + if (!this.#connectFinishedNodes.every((ready) => ready)) + return + this.#roundPeers.keySeq() + .map((id) => { + const startSignal = { + type: MessageTypes.StartWeightSharing, + } + debug("Signaling weight sharing to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(startSignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + // empty the list of peers for the next round this.#roundPeers = Map() - this.#aggregationRound++ + this.#connectFinishedNodes = Map() } } - From dd22a7d1ba1c3f0396a26909e22bfeaafe424e6c Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 15 Apr 2026 16:06:11 +0200 Subject: [PATCH 02/10] Update gitignore --- .gitignore | 3 +++ datasets/.gitignore | 12 +++--------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 4f09db6ff..8e885b6fc 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ dist/ .idea/ .vscode/ *.DS_Store + +# venv +.venv/ diff --git a/datasets/.gitignore b/datasets/.gitignore index 1ae84a880..e644c626a 100644 --- a/datasets/.gitignore +++ b/datasets/.gitignore @@ -2,17 +2,11 @@ /2_QAID_1.masked.reshaped.squared.224.png /9-mnist-example.png /CIFAR10/ -/cifar10-agents -/cifar10-example.png -/cifar10-labels.csv +/cifar10* /simple_face /simple_face-example.png -/titanic_test.csv -/titanic_train.csv -/titanic_train_with_nan.csv -/titanic_test_with_nan.csv -/titanic_wrong_number_columns.csv -/titanic_wrong_passengerID.csv +/titanic* +/mnist* # wikitext /wikitext/ From 7d0e6344d6183c481ce90f3430f761a3b8cb0c47 Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 15 Apr 2026 
16:17:39 +0200 Subject: [PATCH 03/10] Fix lint error --- server/src/controllers/decentralized_controller.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index fa73ea768..2e9f94c28 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -153,7 +153,7 @@ export class DecentralizedController< }).forEach(([conn, encoded]) => { conn.send(encoded) }) // Initialize connectFinishedNodes with all peers set to false - this.#connectFinishedNodes = this.#roundPeers.map(() => false) as Map + this.#connectFinishedNodes = this.#roundPeers.map(() => false) this.#aggregationRound++ } From 947d2c6644a7ebdce19eb5ba4c50cb87c2bf38ae Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Thu, 30 Apr 2026 12:32:36 +0200 Subject: [PATCH 04/10] Add server timeout and retry handling for peer connection failures --- .../decentralized/decentralized_client.ts | 40 +++++- discojs/src/client/decentralized/messages.ts | 21 ++- discojs/src/client/messages.ts | 4 + .../controllers/decentralized_controller.ts | 125 +++++++++++++++++- .../components/training/TrainerDashboard.vue | 16 ++- 5 files changed, 196 insertions(+), 10 deletions(-) diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 2851ce197..0e9611531 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -150,11 +150,41 @@ export class DecentralizedClient extends Client<"decentralized"> { this.saveAndEmit("connecting to peers") // First we check if we are waiting for more participants before sending our weight update await this.waitForParticipantsIfNeeded() - // Create peer-to-peer connections with all peers for the round - await this.establishPeerConnections() - // Wait StartWeightSharing message 
from the server before exchanging weight updates - await waitMessage(this.server, type.StartWeightSharing) - // Exchange weight updates with peers and return aggregated weights // and then send out the contributions + + while(true){ + // Create peer-to-peer connections with all peers for the round + await this.establishPeerConnections() + + // Wait for connection related messages from the server before exchanging weight updates + // (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange + // (2) If it receives a RetryPeerConnections message, it retries peer connection establishment + // (3) After multiple retires, if the connection is still unsuccessful, the server starts excluding nodes from the round + // and sends a ConnectionFail message to those nodes + // (4) Upon receiving ConnectionFail, the client disconnects from the server + const msg = await Promise.race([ + waitMessage(this.server, type.StartWeightSharing), + waitMessage(this.server, type.RetryPeerConnections), + waitMessage(this.server, type.ConnectionFail), + ]) + + if (msg.type === type.StartWeightSharing){ + break + } else if (msg.type === type.RetryPeerConnections){ + debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`) + // clear the communication round peer pool + await this.#pool?.shutdown() + this.#pool = new PeerPool(this.ownId) + // clear the connections + this.#connections = Map() + this.setAggregatorNodes(Set(this.ownId)) + continue + } else if (msg.type === type.ConnectionFail){ + debug(`[${shortenId(this.ownId)}] disconnect from the server`) + await this.disconnect() + throw new Error("Client disconnected after connection failure") + } + } + // Exchange weight updates with peers and return aggregated weights return await this.exchangeWeightUpdates(weights) } diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index ae6b2d891..54dc234c4 100644 --- 
a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -44,9 +44,20 @@ export interface ConnectionsReady { type: type.ConnectionsReady } -// Server signal each peer to start weight update sharing +// Server signals each peer to start weight update sharing export interface StartWeightSharing { - type: type.StartWeightSharing; + type: type.StartWeightSharing +} + +// Server signals peers to reestablish peer connections +export interface RetryPeerConnections { + type: type.RetryPeerConnections + aggregationRound: number +} + +// Server signals a node that the connection with other peers failed +export interface ConnectionFail { + type: type.ConnectionFail } /// Phase 1 communication (between peers) @@ -67,7 +78,9 @@ export type MessageFromServer = PeersForRound | WaitingForMoreParticipants | EnoughParticipants | - StartWeightSharing + StartWeightSharing | + RetryPeerConnections | + ConnectionFail export type MessageToServer = ClientConnected | @@ -94,6 +107,8 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { case type.WaitingForMoreParticipants: case type.EnoughParticipants: case type.StartWeightSharing: + case type.RetryPeerConnections: + case type.ConnectionFail: return true } diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index b0e9dc4e7..c28e8b7f5 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -28,6 +28,10 @@ export enum type { ConnectionsReady, // Sent by the server to signal nodes proceed to weight update sharing StartWeightSharing, + // Sent by the server to signal nodes reestablish connections + RetryPeerConnections, + // Sent by the server to signal that the node's connection was not successful + ConnectionFail, // The weight update Payload, diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 2e9f94c28..d1f9c1f5b 100644 --- 
a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -23,6 +23,8 @@ export class DecentralizedController< #roundPeers = Map() #connectFinishedNodes = Map() #aggregationRound = 0 + #timeout?: NodeJS.Timeout + #connectionRetry= 0 // number of connection retrial for specific aggregationRound handle (ws: WebSocket): void { const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants @@ -154,7 +156,11 @@ export class DecentralizedController< // Initialize connectFinishedNodes with all peers set to false this.#connectFinishedNodes = this.#roundPeers.map(() => false) - this.#aggregationRound++ + // Change the peer states to not ready + this.#roundPeers = this.#roundPeers.map(() => false) + + // Start timeout to check peer connections are successful + this.startTimeout() } /** @@ -163,8 +169,14 @@ export class DecentralizedController< * If so, send StartWeightSharing message to signal peers to proceed */ private signalWeightSharing(): void { + // Return if not all participants are ready if (!this.#connectFinishedNodes.every((ready) => ready)) return + + // Stop the timeout + this.clearTimeout() + + // Send round participants StartWeightSharing messages this.#roundPeers.keySeq() .map((id) => { const startSignal = { @@ -187,5 +199,116 @@ export class DecentralizedController< // empty the list of peers for the next round this.#roundPeers = Map() this.#connectFinishedNodes = Map() + this.#aggregationRound++ + } + + /** + * Set a timeout to check peer connections establishment + */ + private startTimeout(maxTime: number = 60_000): void { + this.#timeout = setTimeout(() => { + this.handleTimeout() + }, maxTime) + } + + /** + * Clear previously set timeout once all peer connections + * are established before the timeout + */ + private clearTimeout(): void { + if (this.#timeout !== undefined){ + clearTimeout(this.#timeout) + this.#timeout = undefined + } + } + + /** + * Called when a timeout occurs during 
peer connection + * Signals peers to discard existing connections and + * reestablish connections with the current set of peers + */ + private handleTimeout(): void { + debug(`Connection setup timeout for round ${this.#aggregationRound}, Retrying with same peers`) + // Increment the connection retry count + this.#connectionRetry += 1; + + // If the number of retries exceeds the threshold, exclude the failed peers from the roundPeers + if (this.#connectionRetry >= 3){ + const numFailedClient = this.#connectFinishedNodes.valueSeq().count((val) => val === false) + const remainingPeers = this.#roundPeers.size - numFailedClient + + // Exclude the failed peers + this.#connectFinishedNodes.forEach((connected, nodeId) => { + if (!connected){ + // If the node failed connection, exclude from #roundPeers + this.#roundPeers = this.#roundPeers.delete(nodeId) + // Signal the node that connection is failed for that node + const conn = this.connections.get(nodeId) + if (conn === undefined) { + throw new Error(`peer ${nodeId} marked as ready but not connection to it`) + } + const failSignal : messages.ConnectionFail = { + type: MessageTypes.ConnectionFail + } + const encoded = msgpack.encode(failSignal) + conn.send(encoded) + } + }) + + // If excluding failed peers would leave too few participants, + // restart the round + // TODO: We need to wait until minNbOfParticipants is satisfied + if (remainingPeers < this.task.trainingInformation.minNbOfParticipants){ + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, 
encoded]) => {conn.send(encoded)}) + + // empty the list of peers for the new round + // round number is not increased since this round failed + this.#roundPeers = Map() + this.#connectFinishedNodes = Map() + this.#connectionRetry = 0 + return + } + } + + // Retry peer connection with the currently remaining round peers + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + this.#connectionRetry = 0 + + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) } } + diff --git a/webapp/src/components/training/TrainerDashboard.vue b/webapp/src/components/training/TrainerDashboard.vue index c7493060e..1ecc775d2 100644 --- a/webapp/src/components/training/TrainerDashboard.vue +++ b/webapp/src/components/training/TrainerDashboard.vue @@ -222,6 +222,9 @@ async function startTraining(): Promise { // manually interrupt the training cleanupDisco.value = async () => await disco.close() + // For the training completed message + let trainingCompleted = true + try { trainingGenerator.value = disco.train(dataset); @@ -245,6 +248,7 @@ async function startTraining(): Promise { epochsOfRoundLogs.value = List(); } } catch (e) { + trainingCompleted = false; if (e === stopper) { toaster.info("Training stopped"); return; @@ -262,6 +266,13 @@ async function startTraining(): Promise { toaster.error( "Training is not converging. 
Data potentially needs better preprocessing.", ); + } else if ( + e instanceof Error && + e.message.includes("Client disconnected after connection failure") + ){ + toaster.error( + "Client disconnected after multiple peer connection failure. Please rejoin the training." + ); } else { toaster.error("An error occurred during training"); } @@ -271,7 +282,10 @@ async function startTraining(): Promise { await cleanupTrainingSession() } - toaster.success("Training successfully completed"); + if (trainingCompleted){ + // printed only when the training is compeleted successfully + toaster.success("Training successfully completed"); + } } async function cleanupTrainingSession() { From b443f53babcf88f0f98c939f45210029d465225f Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Thu, 7 May 2026 17:51:55 +0200 Subject: [PATCH 05/10] Implement decentralized model synchronization for clients joining during training --- discojs/src/client/client.ts | 13 ++ .../decentralized/decentralized_client.ts | 120 ++++++++++++++++- discojs/src/client/decentralized/messages.ts | 50 ++++++- discojs/src/client/messages.ts | 8 ++ discojs/src/training/disco.ts | 19 +++ .../controllers/decentralized_controller.ts | 122 ++++++++++++------ .../components/training/TrainerDashboard.vue | 9 +- 7 files changed, 292 insertions(+), 49 deletions(-) diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9e1298b93..e0fb26372 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -14,6 +14,7 @@ import type { EventConnection } from './event_connection.js' import type { Aggregator } from '../aggregator/index.js' import { EventEmitter } from '../utils/event_emitter.js' import { type } from "./messages.js"; +import { ModelWeightAccess } from "../training/disco.js"; const debug = createDebug("discojs:client"); @@ -38,6 +39,9 @@ export abstract class Client extends EventEmitter<{ */ protected promiseForMoreParticipants: Promise | undefined = undefined; + // Interface to 
access trainer's model weights + protected modelWeightAccess?: ModelWeightAccess; + /** * When the server notifies the client that they can resume training * after waiting for more participants, we want to be able to display what @@ -56,6 +60,15 @@ export abstract class Client extends EventEmitter<{ ) { super() } + + /** + * Used for decentralized learning. + * Set the interface used by client to access to trainer's model weights. + * Disco object provides this access. + */ + setModelWeightAccess(modelWeightAccess: ModelWeightAccess){ + this.modelWeightAccess = modelWeightAccess + } /** * Communication callback called at the beginning of every training round. diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 0e9611531..93aaa3c78 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -26,6 +26,15 @@ export class DecentralizedClient extends Client<"decentralized"> { #pool?: PeerPool #connections?: Map + // Flag if this model requires model synchronization + #modelSyncNeeded?: boolean + + // Check if the training round is in progress + // Used to get the latest model for model synchronization + #isRoundInTraining = false + #roundFinishedPromise?: Promise + #resolveRoundFinished?: () => void // contains resolver + // Used to handle timeouts and promise resolving after calling disconnect private get isDisconnected() : boolean { return this._server === undefined @@ -69,6 +78,24 @@ export class DecentralizedClient extends Client<"decentralized"> { this.#pool.signal(event.peer, event.signal) }) + // Listen if the client is selected as a model provider node for a newly joining client. + // Upon receiving the signal, this client establishes a connection with the newcomer + // and sends the latest model weights. 
+ this.server.on(type.SignalNewPeer, async (event) => { + if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') + const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) + + const newcomerConn = syncConnection.get(event.newNode) + + if (newcomerConn === undefined){ + // if connection with newly joining client fails, print debug message + // and return + debug(`Cannot connect to newly joined client [${event.newNode}]`) + return + } + await this.sendModel(newcomerConn) + }) + // c.f. setupServerCallbacks doc for explanation let receivedEnoughParticipants = false this.setupServerCallbacks(() => receivedEnoughParticipants = true) @@ -79,8 +106,9 @@ export class DecentralizedClient extends Client<"decentralized"> { this.server.send(msg) const { id, waitForMoreParticipants, - nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) - + nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) + + this.#modelSyncNeeded = joinedMidTraining this.nbOfParticipants = nbOfParticipants @@ -129,8 +157,38 @@ export class DecentralizedClient extends Client<"decentralized"> { * When connected, one peer creates a promise for every other peer's weight update * and waits for it to resolve. * + * If a client joined the training after the first round, + * model syncing happens first to get the latest model. */ override async onRoundBeginCommunication(): Promise { + if (this.#modelSyncNeeded) { + // 1. If model sync is needed, send server a request + this.server.send({ type: type.ModelSyncRequest }) + + // 2. 
Get the provider information from the server + const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider") + + if (this.#pool === undefined) { + throw new Error('peer pool is undefined, make sure to call `client.connect()` first') + } + + // 3. Connect with model provider client and get the latest model + const syncConnection = await this.#pool.getPeers( + Set([providerInfo.providerNode]), + this.server, + ()=>{} + ) + const providerConn = syncConnection.get(providerInfo.providerNode) + + if (providerConn === undefined){ + throw new Error("The latest model provider is not connected") + } + + const latestModel = await this.receiveModel(providerConn) + this.modelWeightAccess?.setModelWeight(latestModel) + this.#modelSyncNeeded = false + } + // Notify the server we want to join the next round so that the server // waits for us to be ready before sending the list of peers for the round this.server.send({ type: type.JoinRound }) @@ -149,9 +207,11 @@ export class DecentralizedClient extends Client<"decentralized"> { // Once enough new participants join we can display the previous status again this.saveAndEmit("connecting to peers") // First we check if we are waiting for more participants before sending our weight update - await this.waitForParticipantsIfNeeded() while(true){ + // Wait until enough participants are available before continuing the round + await this.waitForParticipantsIfNeeded() + // Create peer-to-peer connections with all peers for the round await this.establishPeerConnections() @@ -185,7 +245,18 @@ export class DecentralizedClient extends Client<"decentralized"> { } } // Exchange weight updates with peers and return aggregated weights - return await this.exchangeWeightUpdates(weights) + let aggregatedWeight: WeightsContainer + try{ + aggregatedWeight = await this.exchangeWeightUpdates(weights) + } finally { + // Mark the round as finished so that model 
synchronization can proceed + this.#isRoundInTraining = false + this.#resolveRoundFinished?.() + this.#roundFinishedPromise = undefined + this.#resolveRoundFinished = undefined + } + + return aggregatedWeight } /** @@ -211,6 +282,12 @@ export class DecentralizedClient extends Client<"decentralized"> { debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); const receivedMessage = await waitMessage(this.server, type.PeersForRound) + this.#isRoundInTraining = true + // Generate a promise that resolves when round training finishes + this.#roundFinishedPromise = new Promise((resolve) => { + this.#resolveRoundFinished = resolve + }) + const peers = Set(receivedMessage.peers) debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); @@ -338,4 +415,39 @@ export class DecentralizedClient extends Client<"decentralized"> { } return await this.aggregationResult } + + /** + * Receive model from the model provider. + */ + private async receiveModel(providerConn: PeerConnection): Promise{ + const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model") + + const decoded = serialization.weights.decode(message.model) + return decoded + } + + /** + * Send the latest available model to a newly joining client. + * If the current training round is in progress, wait until the round finishes + * and receive the latest aggregated model. 
+ */ + private async sendModel(newcomerConn: PeerConnection): Promise { + if (this.#isRoundInTraining){ + await this.#roundFinishedPromise + } + + const model = this.modelWeightAccess?.getModelWeight() + + if (model === undefined){ + debug("Failed to get the latest model from model provider client") + return + } + const encoded = await serialization.weights.encode(model) + + const message: messages.SharedModel = { + type: type.SharedModel, + model: encoded + } + newcomerConn.send(message) + } } diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index 54dc234c4..1c56f45f8 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -12,6 +12,7 @@ export interface NewDecentralizedNodeInfo { id: NodeID waitForMoreParticipants: boolean nbOfParticipants: number + joinedMidTraining: boolean } // WebRTC signal to forward to other node @@ -60,6 +61,32 @@ export interface ConnectionFail { type: type.ConnectionFail } +// Nodes joining in the middle of the training send to server +// to request the latest model before starting local training +export interface ModelSyncRequest { + type: type.ModelSyncRequest +} + +// Server signals a node that shares the lastest model with node +// who joined in the middle of the training +export interface SignalNewPeer { + type: type.SignalNewPeer + newNode: NodeID +} + +// Server signals new node joining in the middle of the training +// about the model provider node +export interface SignalModelProvider { + type: type.SignalModelProvider + providerNode: NodeID +} + +// Sent by client to another client to share the latest model +export interface SharedModel { + type: type.SharedModel + model: serialization.Encoded +} + /// Phase 1 communication (between peers) export interface Payload { @@ -80,16 +107,21 @@ export type MessageFromServer = EnoughParticipants | StartWeightSharing | RetryPeerConnections | - ConnectionFail + ConnectionFail | 
+ SignalModelProvider | + SignalNewPeer export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | JoinRound | - ConnectionsReady + ConnectionsReady | + ModelSyncRequest -export type PeerMessage = Payload +export type PeerMessage = + Payload | + SharedModel export function isMessageFromServer (o: unknown): o is MessageFromServer { if (!hasMessageType(o)) return false @@ -101,14 +133,17 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { typeof o.waitForMoreParticipants === 'boolean' case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? + 'signal' in o case type.PeersForRound: return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) + case type.SignalNewPeer: + return 'newNode' in o && isNodeID(o.newNode) case type.WaitingForMoreParticipants: case type.EnoughParticipants: case type.StartWeightSharing: case type.RetryPeerConnections: case type.ConnectionFail: + case type.SignalModelProvider: return true } @@ -123,10 +158,11 @@ export function isMessageToServer (o: unknown): o is MessageToServer { return true case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? 
+ 'signal' in o case type.JoinRound: case type.PeerIsReady: case type.ConnectionsReady: + case type.ModelSyncRequest: return true } @@ -142,6 +178,10 @@ export function isPeerMessage (o: unknown): o is PeerMessage { 'peer' in o && isNodeID(o.peer) && 'payload' in o && serialization.isEncoded(o.payload) ) + case type.SharedModel: + return ( + 'model' in o && serialization.isEncoded(o.model) + ) } return false diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index c28e8b7f5..eb6682694 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -34,6 +34,14 @@ export enum type { ConnectionFail, // The weight update Payload, + // Sent by nodes to the server to request the latest model + ModelSyncRequest, + // Sent by the server to nodes to share the provider node info + SignalModelProvider, + // Sent by the server to nodes who was selected as a model provider node + SignalNewPeer, + // Sent by node to node to share the latest model weights + SharedModel, /* Federated */ // The server answers the ClientConnected message with the necessary information diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7dd019b2d..311ad1d7a 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -16,6 +16,7 @@ import type { Network, Task, } from "../index.js"; +import { WeightsContainer } from "../index.js"; import type { Aggregator } from "../aggregator/index.js"; import { getAggregator } from "../aggregator/index.js"; import { enumerate, split } from "../utils/async_iterator.js"; @@ -70,6 +71,15 @@ function buildSummaryLog(roundNum: number, epochNum: number, roundLogs: RoundLog } } +/** + * Interface providing an access to trainer's model weights. + * Used for model synchronization to retrieve and set the latest model. 
+ */ +export interface ModelWeightAccess{ + getModelWeight(): WeightsContainer; + setModelWeight(weight: WeightsContainer): void; +} + /** * Top-level class handling distributed training from a client's perspective. It is meant to be * a convenient object providing a reduced yet complete API that wraps model training and @@ -127,6 +137,15 @@ export class Disco extends EventEmitter<{ this.#client = client; this.#task = task; this.trainer = new Trainer(task, client); + // Set ModelWeightAccess of the client + this.#client.setModelWeightAccess({ + getModelWeight: () => { + return new WeightsContainer(this.trainer.model.weights.weights.map(t => t.clone())); + }, + setModelWeight: (weights) => { + this.trainer.model.weights = weights; + } + }); // Simply propagate the training status events emitted by the client this.#client.on("status", (status) => this.emit("status", status)); this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants)); diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index d1f9c1f5b..ade4c8381 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -2,7 +2,7 @@ import createDebug from "debug"; import { v4 as randomUUID } from 'uuid' import * as msgpack from "@msgpack/msgpack"; import type WebSocket from 'ws' -import { Map } from 'immutable' +import { Map, Set } from 'immutable' import { client, DataType } from "@epfml/discojs"; @@ -24,7 +24,12 @@ export class DecentralizedController< #connectFinishedNodes = Map() #aggregationRound = 0 #timeout?: NodeJS.Timeout - #connectionRetry= 0 // number of connection retrial for specific aggregationRound + + // number of connection retries for the training round + #connectionRetry = 0 + // Client selected to provide the latest model to peers + // joining in the middle of training + #providerNode?: client.NodeID handle (ws: WebSocket): void { 
const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants @@ -49,12 +54,18 @@ export class DecentralizedController< debug(`peer [%s] joined ${this.task.id}`, shortId) this.connections = this.connections.set(peerId, ws) + let joinedMidTraining = false + if (this.#aggregationRound > 0){ + joinedMidTraining = true + } + // Answer with client id in an NewNodeInfo message const msg: messages.NewDecentralizedNodeInfo = { type: MessageTypes.NewDecentralizedNodeInfo, id: peerId, nbOfParticipants: this.connections.size, - waitForMoreParticipants: this.connections.size < minNbOfParticipants + waitForMoreParticipants: this.connections.size < minNbOfParticipants, + joinedMidTraining: joinedMidTraining, } ws.send(msgpack.encode(msg), { binary: true }) // Send an update to participants if we can start/resume training @@ -88,10 +99,40 @@ export class DecentralizedController< break } case MessageTypes.ConnectionsReady: { + // Select the first client that finishes peer connections + // as the model provider for clients joined mid-training + const numconnFinishedNodes = this.#connectFinishedNodes.reduce((acc, val) => acc + (val ? 
1 : 0), 0) + if (!numconnFinishedNodes){ + this.#providerNode = peerId + } + this.#connectFinishedNodes = this.#connectFinishedNodes.set(peerId, true) this.signalWeightSharing() break } + case MessageTypes.ModelSyncRequest: { + // Upon receiving a model sync request, send relevant client information + // to both the model provider and the newly joined client + if (!this.#providerNode) { + debug("There is no provider node to share the latest model") + break + } + + // Signal the newly joined client with the provider client's information + const providerInfo: messages.SignalModelProvider = { + type: MessageTypes.SignalModelProvider, + providerNode: this.#providerNode + } + this.connections.get(peerId)?.send(msgpack.encode(providerInfo)) + + // Signal the provider client with newly joined client's information + const newNodeInfo: messages.SignalNewPeer = { + type: MessageTypes.SignalNewPeer, + newNode: peerId + } + this.connections.get(this.#providerNode)?.send(msgpack.encode(newNodeInfo)) + break + } default: { const _: never = msg throw new Error('should never happen') @@ -106,13 +147,13 @@ export class DecentralizedController< // Remove the participant when the websocket is closed this.connections = this.connections.delete(peerId) this.#roundPeers = this.#roundPeers.delete(peerId) + this.#connectFinishedNodes = this.#connectFinishedNodes.delete(peerId) debug("client [%s] left", shortId) // Check if we are already waiting for new participants to join if (this.waitingForMoreParticipants) return // If no, check if we are still above the minimum number of participant required if (this.connections.size >= minNbOfParticipants) { - // Check if remaining peers are all ready to exchange weight updates this.sendPeersForRoundIfNeeded() return } @@ -228,20 +269,20 @@ export class DecentralizedController< * reestablish connections with the current set of peers */ private handleTimeout(): void { + this.clearTimeout() debug(`Connection setup timeout for round 
${this.#aggregationRound}, Retrying with same peers`) // Increment the connection retry count this.#connectionRetry += 1; - // If the number of retries exceeds the threshold, exclude the failed peers from the roundPeers + // If the number of retries exceeds the threshold, exclude the failed peers from the round + // and retry peer connection only with the remaining peers if (this.#connectionRetry >= 3){ - const numFailedClient = this.#connectFinishedNodes.valueSeq().count((val) => val === false) - const remainingPeers = this.#roundPeers.size - numFailedClient - // Exclude the failed peers this.#connectFinishedNodes.forEach((connected, nodeId) => { if (!connected){ // If the node failed connection, exclude from #roundPeers this.#roundPeers = this.#roundPeers.delete(nodeId) + this.#connectFinishedNodes = this.#connectFinishedNodes.delete(nodeId) // Signal the node that connection is failed for that node const conn = this.connections.get(nodeId) if (conn === undefined) { @@ -255,41 +296,41 @@ export class DecentralizedController< } }) - // If excluding failed peers would leave too few participants, - // restart the round - // TODO: We need to wait until minNbOfParticipants is satisfied - if (remainingPeers < this.task.trainingInformation.minNbOfParticipants){ - this.#roundPeers.keySeq() - .map((id) => { - const retrySignal = { - type: MessageTypes.RetryPeerConnections, - } - debug("Signaling connection retry to: %o", id.slice(0, 4)) + // Restart the round with remaining clients + this.#roundPeers.keySeq() + .map((id) => { + const retrySignal = { + type: MessageTypes.RetryPeerConnections, + } + debug("Signaling connection retry to: %o", id.slice(0, 4)) - const encoded = msgpack.encode(retrySignal) - return [id, encoded] as [client.NodeID, Buffer] - }) - .map(([id, encoded]) => { - const conn = this.connections.get(id) - if (conn === undefined) { - throw new Error(`peer ${id} marked as ready but not connection to it`) - } - return [conn, encoded] as [WebSocket, Buffer] 
- }) - .forEach(([conn, encoded]) => {conn.send(encoded)}) - - // empty the list of peers for the new round - // round number is not increased since this round failed - this.#roundPeers = Map() - this.#connectFinishedNodes = Map() - this.#connectionRetry = 0 - return - } + const encoded = msgpack.encode(retrySignal) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }) + .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // Reset the ready and connection status of roundPeers + this.#roundPeers = this.#roundPeers.map(() => false) + this.#connectFinishedNodes = this.#roundPeers.map(() => false) + + // Reset the connectionRetry since we excluded the failed clients + this.#connectionRetry = 0 + // Restart the timeout after sending retry messages + this.startTimeout() + return } - // Retry peer connection with the currently remaining round peers + // Retry peer connection with original roundPeers + // Reset the ready and connection status of roundPeers + this.#roundPeers = this.#roundPeers.map(() => false) this.#connectFinishedNodes = this.#roundPeers.map(() => false) - this.#connectionRetry = 0 this.#roundPeers.keySeq() .map((id) => { @@ -309,6 +350,9 @@ export class DecentralizedController< return [conn, encoded] as [WebSocket, Buffer] }) .forEach(([conn, encoded]) => {conn.send(encoded)}) + + // Restart the timeout after sending retry messages + this.startTimeout() } } diff --git a/webapp/src/components/training/TrainerDashboard.vue b/webapp/src/components/training/TrainerDashboard.vue index 1ecc775d2..3680175b1 100644 --- a/webapp/src/components/training/TrainerDashboard.vue +++ b/webapp/src/components/training/TrainerDashboard.vue @@ -273,8 +273,15 @@ async function startTraining(): Promise { toaster.error( "Client disconnected after 
multiple peer connection failure. Please rejoin the training." ); + } else if ( + e instanceof Error && + e.message.includes("Timeout while waiting for the latest model") + ){ + toaster.error( + "Timeout while waiting for the model syncing. Please rejoin the training." + ); } else { - toaster.error("An error occurred during training"); + toaster.error("An error occurred during training.") } debug("while training: %o", e); } finally { From 9c7c4883bd46fc9826969f3779df65e4bdb77def Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Tue, 2 Jun 2026 23:07:14 +0200 Subject: [PATCH 06/10] Improve model syncing and add max retry parameter --- discojs/src/client/client.ts | 5 + discojs/src/client/decentralized/README.md | 54 +++ .../decentralized/decentralized_client.ts | 50 +-- discojs/src/default_tasks/cifar10.ts | 3 +- discojs/src/default_tasks/mnist.ts | 3 +- discojs/src/task/training_information.ts | 3 + discojs/src/training/disco.ts | 4 +- discojs/src/training/trainer.ts | 1 + .../controllers/decentralized_controller.ts | 26 +- server/tests/e2e/decentralized.spec.ts | 313 +++++++++++++++++- .../task_creation_form/TaskCreationForm.vue | 26 +- 11 files changed, 457 insertions(+), 31 deletions(-) create mode 100644 discojs/src/client/decentralized/README.md diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index e0fb26372..9999acc5f 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -25,6 +25,7 @@ const debug = createDebug("discojs:client"); export abstract class Client extends EventEmitter<{ status: RoundStatus; participants: number; + modelSynced: WeightsContainer | undefined; }> { // Own ID provided by the network's server. 
protected _ownId?: NodeID @@ -206,6 +207,10 @@ export abstract class Client extends EventEmitter<{ return await serialization.model.decode(encoded) } + public finishRound(): void{ + // DecentralizedClient override the method to clean up round state + } + /** * Number of contributors to a collaborative session * If decentralized, it should be the number of peers diff --git a/discojs/src/client/decentralized/README.md b/discojs/src/client/decentralized/README.md new file mode 100644 index 000000000..07d581603 --- /dev/null +++ b/discojs/src/client/decentralized/README.md @@ -0,0 +1,54 @@ +# Peer Connection in Decentralized Learning +This document describes how peer connections for decentralized learning are established. + +Peer connections for decentralized learning are coordinated by the server. However, model weight updates are shared only between peers. + +## Peer Connection and Round Participation + +Clients first connect to the server. The server then coordinates which clients should participate in each decentralized training round. + +At the beginning of each round, clients send `JoinRound` messages to the server. After that, they send `PeerIsReady` messages to notify the server that they are ready to establish peer connections and exchange model updates. + +For each round, the server assumes that all currently connected clients are participating, except for clients that are still syncing their model after joining in the middle of training. Once the server has received `PeerIsReady` messages from all expected participants for the round, it sends a `PeersForRound` message to each participant. This message tells clients which peers they should connect to for the current round. + +After receiving `PeersForRound`, participants establish peer connections and proceed with decentralized learning. 
+ +### Connection Ready Check and Signaling Weight Sharing + +Before clients start sharing model weights, the server checks that all participants have successfully completed peer connection setup. This prevents faster clients from starting weight sharing while slower clients are still establishing connections. +The process is as follows: +1. After completing peer connections, each client sends a `ConnectionsReady` message to the server. +2. The server counts the number of received `ConnectionsReady` messages. +3. Once the number of `ConnectionsReady` messages matches the number of expected participants for the round, the server sends a `StartWeightSharing` message to all round participants. +4. Clients only start sharing model updates after receiving `StartWeightSharing`. + +### Connection Retries and Failed Client Disconnection + +Peer connections may fail, so decentralized training includes a retry mechanism. The maximum number of retries is controlled by `maxConnectionRetry`, which is specified in the task training information. + +The retry mechanism is triggered when the server times out while waiting for `ConnectionsReady` messages. +The process is as follows: +1. If the number of retries is still below `maxConnectionRetry`, the server sends a `RetryPeerConnections` message to all peers in the current round. +2. When clients receive `RetryPeerConnections`, they clean up their peer pool and aggregator nodes, then rerun the peer connection phase. +3. Once the number of retries reaches `maxConnectionRetry` and connection setup has still not completed, the server removes the failed peers from the round peer list. +4. The server sends a `ConnectionFail` message to the failed peers. +5. When a client receives `ConnectionFail`, it disconnects from the server. +6. The remaining clients receive `RetryPeerConnections` and retry the peer connection phase without the failed peers.
+ +This allows the remaining participants to continue training even when one or more peers fail to establish connections. + +### Model Syncing for Participants Joining in the Middle of Training + +Participants that join in the middle of training need to receive the latest model before they can participate in future rounds. In decentralized learning, peers do not send model weights to the server, so the newcomer must request the latest model from an existing peer. + +The model syncing process is as follows: +1. When a new participant joins in the middle of training, the server marks the participant as having joined mid-training in `NewDecentralizedNodeInfo`. +2. After receiving `NewDecentralizedNodeInfo`, the new client sets a local flag indicating that it needs model syncing. +3. When training begins, if this flag is set, the new client sends a `ModelSyncRequest` message to the server. +4. After receiving `ModelSyncRequest`, the server carries out steps 5 and 6, using the model provider it selected during a previous training round. +5. The server sends `SignalModelProvider` to the new participant with information about the provider peer. +6. The server sends `SignalNewPeer` to the provider peer with information about the newly joined peer. +7. Using this signaling information, the new participant and provider peer establish a peer connection. +8. The provider waits until the current aggregation round has finished, then sends the latest model to the new participant using a `SharedModel` message. +9. The new participant receives the model and updates its local model weights. +10. After syncing, the new participant can join subsequent decentralized training rounds.
\ No newline at end of file diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 93aaa3c78..81e0f2f56 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -31,7 +31,6 @@ export class DecentralizedClient extends Client<"decentralized"> { // Check if the training round is in progress // Used to get the latest model for model synchronization - #isRoundInTraining = false #roundFinishedPromise?: Promise #resolveRoundFinished?: () => void // contains resolver @@ -83,6 +82,7 @@ export class DecentralizedClient extends Client<"decentralized"> { // and sends the latest model weights. this.server.on(type.SignalNewPeer, async (event) => { if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') + const roundFinishedPromise = this.#roundFinishedPromise const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) const newcomerConn = syncConnection.get(event.newNode) @@ -93,7 +93,7 @@ export class DecentralizedClient extends Client<"decentralized"> { debug(`Cannot connect to newly joined client [${event.newNode}]`) return } - await this.sendModel(newcomerConn) + await this.sendModel(newcomerConn, roundFinishedPromise) }) // c.f. setupServerCallbacks doc for explanation @@ -186,6 +186,8 @@ export class DecentralizedClient extends Client<"decentralized"> { const latestModel = await this.receiveModel(providerConn) this.modelWeightAccess?.setModelWeight(latestModel) + + this.emit("modelSynced", this.modelWeightAccess?.getModelWeight()) this.#modelSyncNeeded = false } @@ -195,6 +197,10 @@ export class DecentralizedClient extends Client<"decentralized"> { // Store the promise for the current round's aggregation result. // We will await for it to resolve at the end of the round when exchanging weight updates. 
this.aggregationResult = this.aggregator.getPromiseForAggregation() + + // Do not proceed to local training when minNbOfParticipants condition is not satisfied + await this.waitForParticipantsIfNeeded() + this.saveAndEmit("local training") return Promise.resolve() } @@ -210,6 +216,8 @@ export class DecentralizedClient extends Client<"decentralized"> { while(true){ // Wait until enough participants are available before continuing the round + // Checks minNbOfParticipants requirement for + // when participants disconnect when connection error happens continuously await this.waitForParticipantsIfNeeded() // Create peer-to-peer connections with all peers for the round @@ -228,6 +236,12 @@ export class DecentralizedClient extends Client<"decentralized"> { ]) if (msg.type === type.StartWeightSharing){ + // Generate a promise that resolves when round training finishes + if (this.#roundFinishedPromise === undefined){ + this.#roundFinishedPromise = new Promise((resolve) => { + this.#resolveRoundFinished = resolve + }) + } break } else if (msg.type === type.RetryPeerConnections){ debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`) @@ -245,16 +259,7 @@ export class DecentralizedClient extends Client<"decentralized"> { } } // Exchange weight updates with peers and return aggregated weights - let aggregatedWeight: WeightsContainer - try{ - aggregatedWeight = await this.exchangeWeightUpdates(weights) - } finally { - // Mark the round as finished so that model synchronization can proceed - this.#isRoundInTraining = false - this.#resolveRoundFinished?.() - this.#roundFinishedPromise = undefined - this.#resolveRoundFinished = undefined - } + const aggregatedWeight = await this.exchangeWeightUpdates(weights) return aggregatedWeight } @@ -282,12 +287,6 @@ export class DecentralizedClient extends Client<"decentralized"> { debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); const receivedMessage = await 
waitMessage(this.server, type.PeersForRound) - this.#isRoundInTraining = true - // Generate a promise that resolves when round training finishes - this.#roundFinishedPromise = new Promise((resolve) => { - this.#resolveRoundFinished = resolve - }) - const peers = Set(receivedMessage.peers) debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); @@ -431,9 +430,10 @@ export class DecentralizedClient extends Client<"decentralized"> { * If the current training round is in progress, wait until the round finishes * and receive the latest aggregated model. */ - private async sendModel(newcomerConn: PeerConnection): Promise { - if (this.#isRoundInTraining){ - await this.#roundFinishedPromise + private async sendModel(newcomerConn: PeerConnection, roundFinishedPromise: Promise | undefined): Promise { + // wait until the round finishes to get the latest model + if (roundFinishedPromise !== undefined){ + await roundFinishedPromise } const model = this.modelWeightAccess?.getModelWeight() @@ -450,4 +450,12 @@ export class DecentralizedClient extends Client<"decentralized"> { } newcomerConn.send(message) } + + // Resolve the round finished promise and reset related state + override finishRound(): void{ + // Mark round as finished so that model synchronization can proceed + this.#resolveRoundFinished?.() + this.#roundFinishedPromise = undefined + this.#resolveRoundFinished = undefined + } } diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts index 3a3537440..de00ef706 100644 --- a/discojs/src/default_tasks/cifar10.ts +++ b/discojs/src/default_tasks/cifar10.ts @@ -47,7 +47,8 @@ export const cifar10: TaskProvider<"image", "decentralized"> = { }, minNbOfParticipants: 3, maxShareValue: 100, - tensorBackend: 'tfjs' + tensorBackend: 'tfjs', + maxConnectionRetry: 3, } }); }, diff --git a/discojs/src/default_tasks/mnist.ts b/discojs/src/default_tasks/mnist.ts index 71b148507..0be9cf1ae 100644 --- 
a/discojs/src/default_tasks/mnist.ts +++ b/discojs/src/default_tasks/mnist.ts @@ -43,7 +43,8 @@ export const mnist: TaskProvider<"image", "decentralized"> = { }, minNbOfParticipants: 3, maxShareValue: 100, - tensorBackend: 'tfjs' + tensorBackend: 'tfjs', + maxConnectionRetry: 3, } }); }, diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index fa01cdf2e..55ef123ed 100644 --- a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -100,6 +100,9 @@ export namespace TrainingInformation { .object({ scheme: z.literal("decentralized"), aggregationStrategy: z.literal(["byzantine", "mean", "secure"]), + + // Maximum number of retries for connection failures + maxConnectionRetry: z.number().nonnegative().int().default(3), }) .and(nonLocalNetworkSchema), federated: z diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 311ad1d7a..50c173680 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -87,7 +87,8 @@ export interface ModelWeightAccess{ */ export class Disco extends EventEmitter<{ status: RoundStatus; - participants: number + participants: number; + modelSynced: WeightsContainer | undefined; }> { public readonly trainer: Trainer; readonly #client: clients.Client; @@ -149,6 +150,7 @@ export class Disco extends EventEmitter<{ // Simply propagate the training status events emitted by the client this.#client.on("status", (status) => this.emit("status", status)); this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants)); + this.#client.on("modelSynced", (latestWeights) => this.emit("modelSynced", latestWeights)); } /** Train on dataset, yielding logs of every round. 
*/ diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index 68e716bcc..87dada9db 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -141,6 +141,7 @@ export class Trainer { // Update the local weights this.model.weights = networkWeights; + this.#client.finishRound(); } } diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index ade4c8381..03af8d942 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -27,10 +27,14 @@ export class DecentralizedController< // number of connection retries for the training round #connectionRetry = 0 + // Client selected to provide the latest model to peers // joining in the middle of training #providerNode?: client.NodeID + // Set of nodes that are syncing node + #syncingNodes = Set() + handle (ws: WebSocket): void { const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants @@ -53,10 +57,13 @@ export class DecentralizedController< case MessageTypes.ClientConnected: { debug(`peer [%s] joined ${this.task.id}`, shortId) this.connections = this.connections.set(peerId, ws) - + + // If the new peer joins in the middle of training, + // it needs to get the latest model from an existing peer let joinedMidTraining = false if (this.#aggregationRound > 0){ joinedMidTraining = true + this.#syncingNodes = this.#syncingNodes.add(peerId) } // Answer with client id in an NewNodeInfo message @@ -75,6 +82,7 @@ export class DecentralizedController< // Send by peers at the beginning of each training round to notify // the server that they want to join the round case MessageTypes.JoinRound: { + this.#syncingNodes = this.#syncingNodes.delete(peerId) this.#roundPeers = this.#roundPeers.set(peerId, false) break } @@ -150,6 +158,12 @@ export class DecentralizedController< this.#connectFinishedNodes = this.#connectFinishedNodes.delete(peerId) 
debug("client [%s] left", shortId) + // If this participant was a latest model provider node, + // replace the provider node to another node + if (this.#providerNode === peerId) { + this.#providerNode = this.connections.keySeq().first() + } + // Check if we are already waiting for new participants to join if (this.waitingForMoreParticipants) return // If no, check if we are still above the minimum number of participant required @@ -170,10 +184,16 @@ export class DecentralizedController< private sendPeersForRoundIfNeeded(): void { const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants const nbOfPeersReady = this.#roundPeers.filter(ready => ready).size + // participating peers are connected peers expect the ones that are in process of syncing + const participatingPeers = this.connections.keySeq().toSet().subtract(this.#syncingNodes) + // First check if there are enough participants to start the round // Then check if all peers that wanted to join this round are ready + // All peers that are connected to the server (except for newly joining peers waiting for the latest model) + // are expected to participate in the round + if (nbOfPeersReady < minNbOfParticipants - || nbOfPeersReady != this.#roundPeers.size) return + || nbOfPeersReady != participatingPeers.size) return // Once every peer that joined the round is ready, we can start the round this.#roundPeers.keySeq() .map((id) => { @@ -276,7 +296,7 @@ export class DecentralizedController< // If the number of retries exceeds the threshold, exclude the failed peers from the round // and retry peer connection only with the remaining peers - if (this.#connectionRetry >= 3){ + if (this.#connectionRetry >= this.task.trainingInformation.maxConnectionRetry){ // Exclude the failed peers this.#connectFinishedNodes.forEach((connected, nodeId) => { if (!connected){ diff --git a/server/tests/e2e/decentralized.spec.ts b/server/tests/e2e/decentralized.spec.ts index 0a0fb29af..71f8065b2 100644 --- 
a/server/tests/e2e/decentralized.spec.ts +++ b/server/tests/e2e/decentralized.spec.ts @@ -1,5 +1,5 @@ import type * as http from "node:http"; -import type { DataType, RoundStatus, Task, TaskProvider } from "@epfml/discojs"; +import type { DataType, RoundStatus, Task, TaskProvider, EpochLogs } from "@epfml/discojs"; import { aggregator as aggregators, client as clients, @@ -27,6 +27,47 @@ async function expectWSToBeClose( expect(l).to.be.closeTo(r, 1e-4); } + +// function from federated.spec.ts +export async function arrayFromAsync(iter: AsyncIterable): Promise { + const ret: T[] = []; + for await (const e of iter) { + // TODO trick to allow other Promises to run + // else one client might progress alone without communicating with others + // will be fixed when client orchestrations in the server is correctly done + await new Promise((resolve) => setTimeout(resolve, 10)); + + ret.push(e); + } + return ret; +} + +// function to check if weights across all participants are close to each other +async function expectAllWSToBeClose( + weights: WeightsContainer[] +): Promise { + const reference = weights[0] + + await Promise.all( + weights.map(async (current) => { + await expectWSToBeClose(reference, current) + }) + ) +} + +const expectWeightsToEqual = ( + a: WeightsContainer, + b: WeightsContainer, +) => { + expect(a.weights.length).to.equal(b.weights.length); + + a.weights.forEach((w, i) => { + expect(Array.from(w.dataSync())).to.deep.equal( + Array.from(b.weights[i].dataSync()), + ); + }); +}; + describe("end-to-end decentralized", { timeout: 50_000 }, () => { let handle: http.Server | undefined; async function startServer( @@ -134,7 +175,147 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { await reachConsensus(url, "secure", 3); }); - it("peers emit expected events", { timeout: 100_000 }, async () => { + /** + * Unit tests with 10 participants + */ + // Mean aggregator + it("ten cifar10 users reach consensus with mean aggregation", { timeout: 
300_000 }, async () => { + const baseTask = await defaultTasks.cifar10.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "mean", + epochs: 3, + roundDuration: 1, + minNbOfParticipants: 10, + }, + }; + + const url = await startServer({ + ...defaultTasks.cifar10, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadCifar10(); + + const discos = Array.from( + { length: 10 }, + () => new Disco(task, url, { preprocessOnce: true }), + ); + + try{ + const results = await Promise.all( + discos.map(async (disco) => { + const logs = List(await arrayFromAsync(disco.trainByRound(dataset))); + const lastEpoch = logs.last()?.epochs.last(); + if (lastEpoch === undefined) throw new Error("no epoch ran"); + + return [disco.trainer.model.weights, lastEpoch] as [WeightsContainer, EpochLogs]; + }) + ); + + await expectAllWSToBeClose(results.map(([weights])=>weights)); + }finally{ + await Promise.all(discos.map((disco) => disco.close())); + } + }); + + // Byzantine aggregator + it("ten cifar10 users reach consensus with byzantine aggregation", { timeout: 300_000 }, async () => { + const baseTask = await defaultTasks.cifar10.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "byzantine", + epochs: 3, + roundDuration: 1, + minNbOfParticipants: 10, + privacy: { + byzantineFaultTolerance: { + clippingRadius: 10, + maxIterations: 1, + beta: 0.9, + }, + }, + }, + }; + + const url = await startServer({ + ...defaultTasks.cifar10, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadCifar10(); + + const discos = Array.from( + { length: 10 }, + () => new Disco(task, url, { preprocessOnce: true }), + ); + + try{ + const results = await Promise.all( + discos.map(async 
(disco) => { + const logs = List(await arrayFromAsync(disco.trainByRound(dataset))); + const lastEpoch = logs.last()?.epochs.last(); + if (lastEpoch === undefined) throw new Error("no epoch ran"); + + return [disco.trainer.model.weights, lastEpoch] as [WeightsContainer, EpochLogs]; + }) + ); + + await expectAllWSToBeClose(results.map(([weights])=>weights)); + }finally{ + await Promise.all(discos.map((disco) => disco.close())); + } + }); + + // Secure aggregator + it("ten cifar10 users reach consensus with secure aggregation", { timeout: 500_000 }, async () => { + const baseTask = await defaultTasks.cifar10.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "secure", + epochs: 10, + roundDuration: 1, + minNbOfParticipants: 10, + maxShareValue: 100, + }, + }; + + const url = await startServer({ + ...defaultTasks.cifar10, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadCifar10(); + + const discos = Array.from( + { length: 10 }, + () => new Disco(task, url, { preprocessOnce: true }), + ); + + try{ + const results = await Promise.all( + discos.map(async (disco) => { + const logs = List(await arrayFromAsync(disco.trainByRound(dataset))); + const lastEpoch = logs.last()?.epochs.last(); + if (lastEpoch === undefined) throw new Error("no epoch ran"); + + return [disco.trainer.model.weights, lastEpoch] as [WeightsContainer, EpochLogs]; + }) + ); + + await expectAllWSToBeClose(results.map(([weights])=>weights)); + }finally{ + await Promise.all(discos.map((disco) => disco.close())); + } + }); + + it("peers emit expected events", { timeout: 300_000 }, async () => { const baseTask = await defaultTasks.lusCovid.getTask(); const task: Task<"image", "decentralized"> = { ...baseTask, @@ -144,6 +325,7 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { aggregationStrategy: "mean", roundDuration: 1, 
minNbOfParticipants: 2, + maxConnectionRetry: 3, }, }; const url = await startServer({ @@ -326,4 +508,131 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { await discoUser3.close() }); + + /** + * We test if the latest model syncing is working when new participant + * joins in the middle of the training (when the round > 0). + * + * The test workflow + * 1. Start User1 and User2 starts training + * 2. Let them complete at least one aggregation round + * 3. Start User3 when aggregationRound is larger than 0 + * 4. When User3 starts training, model synchronization should be triggered first + * 5. Compare User3's model weights with User1/User2's latest model weights + */ + it("Model Syncing when new participant joins in the middle of the training", { timeout: 200_000 }, async () => { + const baseTask = await defaultTasks.lusCovid.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "mean", + roundDuration: 1, + minNbOfParticipants: 2, + maxConnectionRetry: 3, + }, + }; + + const url = await startServer({ + ...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); + + const discoUser1 = new Disco(task, url, { preprocessOnce: true }); + const discoUser2 = new Disco(task, url, { preprocessOnce: true }); + const discoUser3 = new Disco(task, url, { preprocessOnce: true }); + + try { + const generatorUser1 = discoUser1.trainByRound(dataset); + const generatorUser2 = discoUser2.trainByRound(dataset); + + await Promise.all([ + generatorUser1.next(), + generatorUser2.next(), + ]); + + await Promise.all([ + generatorUser1.next(), + generatorUser2.next(), + ]); + + // Existing participants should already have the same aggregated model. 
+ await expectWSToBeClose( + discoUser1.trainer.model.weights, + discoUser2.trainer.model.weights, + ); + + const waitForModelSynced = Promise.race([ + new Promise<{ + syncedWeights: WeightsContainer; + providerWeightsUser1: WeightsContainer; + providerWeightsUser2: WeightsContainer; + }>((resolve) => { + discoUser3.on("modelSynced", (weights) => { + if (weights !== undefined) { + resolve({ + syncedWeights: weights, + providerWeightsUser1: new WeightsContainer( + discoUser1.trainer.model.weights.weights.map((w) => w.clone()), + ), + providerWeightsUser2: new WeightsContainer( + discoUser2.trainer.model.weights.weights.map((w) => w.clone()), + ), + }); + } + }); + }), + new Promise((_, reject) => + setTimeout( + () => reject(new Error("Timed out waiting for modelSynced")), + 60_000, + ), + ), + ]); + + const generatorUser3 = discoUser3.trainByRound(dataset); + const user3RoundPromise = generatorUser3.next(); + + await new Promise((resolve) => setTimeout(resolve, 5_000)); + + // The newcomer may ask for synchronization while existing participants are + // already in the next local round. Progress peers until the provider sends the latest model. 
+ for (let attempt = 0; attempt < 5; attempt++) { + const synced = await Promise.race([ + waitForModelSynced.then(() => true), + new Promise((resolve) => + setTimeout(() => resolve(false), 100), + ), + ]); + + if (synced) break; + + await Promise.all([ + generatorUser1.next(), + generatorUser2.next(), + ]); + } + + const { + syncedWeights, + providerWeightsUser1, + providerWeightsUser2, + } = await waitForModelSynced; + + try { + expectWeightsToEqual(syncedWeights, providerWeightsUser1); + } catch { + expectWeightsToEqual(syncedWeights, providerWeightsUser2); + } + + const user3Round = await user3RoundPromise; + expect(user3Round.done).to.be.false; + } finally { + await discoUser1.close(); + await discoUser2.close(); + await discoUser3.close(); + } + }); }) diff --git a/webapp/src/components/task_creation_form/TaskCreationForm.vue b/webapp/src/components/task_creation_form/TaskCreationForm.vue index 5ff1d705a..fb875e7ab 100644 --- a/webapp/src/components/task_creation_form/TaskCreationForm.vue +++ b/webapp/src/components/task_creation_form/TaskCreationForm.vue @@ -348,6 +348,25 @@ + +
+ Maximum number of connection retries before disconnecting failed participants. +
+ + +
+ + Date: Tue, 2 Jun 2026 23:47:55 +0200 Subject: [PATCH 07/10] Fix promise returning problem during SignalNewPeer message processing --- .../decentralized/decentralized_client.ts | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 81e0f2f56..f37fae592 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -44,6 +44,24 @@ export class DecentralizedClient extends Client<"decentralized"> { // Emits the `participants` event this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size } + + // Used by model provider peer during model syncing + private async handleSignalNewPeer(event: any): Promise { + if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') + const roundFinishedPromise = this.#roundFinishedPromise + const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) + + const newcomerConn = syncConnection.get(event.newNode) + + if (newcomerConn === undefined){ + // if connection with newly joining client fails, print debug message + // and return + debug(`Cannot connect to newly joined client [${event.newNode}]`) + return + } + + await this.sendModel(newcomerConn, roundFinishedPromise) + } /** * Public method called by disco.ts when starting training. This method sends @@ -80,20 +98,8 @@ export class DecentralizedClient extends Client<"decentralized"> { // Listen if the client is selected as a model provider node for a newly joining client. // Upon receiving the signal, this client establishes a connection with the newcomer // and sends the latest model weights. 
- this.server.on(type.SignalNewPeer, async (event) => { - if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') - const roundFinishedPromise = this.#roundFinishedPromise - const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) - - const newcomerConn = syncConnection.get(event.newNode) - - if (newcomerConn === undefined){ - // if connection with newly joining client fails, print debug message - // and return - debug(`Cannot connect to newly joined client [${event.newNode}]`) - return - } - await this.sendModel(newcomerConn, roundFinishedPromise) + this.server.on(type.SignalNewPeer, (event) => { + void this.handleSignalNewPeer(event).catch((e: unknown) => { debug('Error while sending the latest model to newcomer [%s]: %o', event.newNode, e) }) }) // c.f. setupServerCallbacks doc for explanation From 14b2a4b18e990cc8b2e20ee2a2b2b34f89d538bf Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 3 Jun 2026 00:44:47 +0200 Subject: [PATCH 08/10] Specify event type --- discojs/src/client/decentralized/decentralized_client.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index f37fae592..3840e7273 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -5,7 +5,7 @@ import type { DataType, Model, WeightsContainer } from "../../index.js"; import { serialization } from "../../index.js"; import { Client, shortenId } from '../client.js' import { type NodeID } from '../index.js' -import { type, type ClientConnected } from '../messages.js' +import { type, type ClientConnected, type NarrowMessage } from '../messages.js' import { timeout } from '../utils.js' import { WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js' import { PeerPool } from './peer_pool.js' @@ -46,7 +46,7 @@ export class DecentralizedClient extends Client<"decentralized"> { 
} // Used by model provider peer during model syncing - private async handleSignalNewPeer(event: any): Promise { + private async handleSignalNewPeer(event: NarrowMessage): Promise { if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') const roundFinishedPromise = this.#roundFinishedPromise const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) From 3bfce91dc9672e104cdd808aa9918cfd87b5c690 Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 17 Jun 2026 20:21:26 +0200 Subject: [PATCH 09/10] Add learning parameters reset & Move documentations --- discojs/src/aggregator/aggregator.ts | 13 ++++++++++ .../decentralized/decentralized_client.ts | 4 ++- .../README.md => docs/DECENTRALIZED.md | 1 + docs/FEDERATED.md | 19 ++++++++++++++ .../controllers/decentralized_controller.ts | 26 +++++++++++++++++++ .../src/controllers/federated_controller.ts | 19 +++++++++++--- server/src/controllers/training_controller.ts | 10 +++++++ 7 files changed, 88 insertions(+), 4 deletions(-) rename discojs/src/client/decentralized/README.md => docs/DECENTRALIZED.md (95%) create mode 100644 docs/FEDERATED.md diff --git a/discojs/src/aggregator/aggregator.ts b/discojs/src/aggregator/aggregator.ts index f21aff2aa..dc9f03712 100644 --- a/discojs/src/aggregator/aggregator.ts +++ b/discojs/src/aggregator/aggregator.ts @@ -196,6 +196,19 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon this._nodes = this._nodes.delete(nodeId) } + /** + * Dispose the contributions to clean tensor memory + */ + dispose(): void { + this.contributions.forEach((roundContributions) => { + roundContributions.forEach((contribution) => { + contribution.dispose() + }) + }) + + this.contributions = Map() + } + /** * Overwrites the current set of active nodes with the given one. 
A node represents * an active neighbor peer/client within the network, whom we are communicating with diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 3840e7273..b616f99f8 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -47,7 +47,9 @@ export class DecentralizedClient extends Client<"decentralized"> { // Used by model provider peer during model syncing private async handleSignalNewPeer(event: NarrowMessage): Promise { - if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') + if (this.#pool === undefined){ + throw new Error('received signal about new peer but peer pool is undefined') + } const roundFinishedPromise = this.#roundFinishedPromise const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) diff --git a/discojs/src/client/decentralized/README.md b/docs/DECENTRALIZED.md similarity index 95% rename from discojs/src/client/decentralized/README.md rename to docs/DECENTRALIZED.md index 07d581603..fa3c74b87 100644 --- a/discojs/src/client/decentralized/README.md +++ b/docs/DECENTRALIZED.md @@ -1,5 +1,6 @@ # Peer Connection in Decentralized Learning This document describes how peer connections for decentralized learning are established. +Relevant code can be found in [decentralized controller](../server/src/controllers/decentralized_controller.ts) and [decentralized client](../discojs/src/client/decentralized/decentralized_client.ts). Peer connections for decentralized learning are coordinated by the server. However, model weight updates are shared only between peers. 
diff --git a/docs/FEDERATED.md b/docs/FEDERATED.md new file mode 100644 index 000000000..347d19d1e --- /dev/null +++ b/docs/FEDERATED.md @@ -0,0 +1,19 @@ +# Connections and Aggregations in Federated Learning +This document describes how connections between the server and clients are established, how model updates are aggregated, and how updated weights are distributed. + +## Connecting to the Server +Clients participating in federated learning connect directly to the server. The server acts as the central coordination and aggregation point. Therefore, the clients only need to establish a connection with the server. + +When a client connects, the server assigns it a client ID and sends it the latest available global model weights and training information. The client initializes its local model with these weights and can then begin local training. + +## Aggregating Model Updates +After finishing local training for a round, each client sends its model update to the server using a `SendPayload` message. This message contains the client's current round number so that the server can synchronize the aggregation of model weights. + +The server checks the weight update contribution and adds it to the aggregator. When the required number of contributions has been received, the aggregator combines the client updates according to the aggregation mode and produces a new global model update. + +The server then sends the aggregated result to each participant using a `ReceiveServerPayload` message. Each client updates its local model to the received weights and proceeds to the next training round. + +After every successful aggregation, the server also stores the resulting weights as the latest global model weights. + +## Clients Joining During Training +When a new client joins an ongoing training round, the server sends it the latest available global model weights. The new client can then begin local training from the latest globally aggregated model. 
\ No newline at end of file diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 03af8d942..ad984ab13 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -156,6 +156,14 @@ export class DecentralizedController< this.connections = this.connections.delete(peerId) this.#roundPeers = this.#roundPeers.delete(peerId) this.#connectFinishedNodes = this.#connectFinishedNodes.delete(peerId) + + // Reset the training session when all participants leave + if (this.connections.size === 0){ + debug("All participants left. Resetting decentralized training session") + this.reset() + return + } + debug("client [%s] left", shortId) // If this participant was a latest model provider node, @@ -176,6 +184,24 @@ export class DecentralizedController< this.sendWaitForMoreParticipantsMsg() }) } + + reset(): void { + this.resetConnectionState() + + this.#roundPeers = Map() + this.#connectFinishedNodes = Map() + this.#aggregationRound = 0 + this.#connectionRetry = 0 + this.#providerNode = undefined + this.#syncingNodes = Set() + + // Reset the timeout + if (this.#timeout !== undefined){ + clearTimeout(this.#timeout) + this.#timeout = undefined + } + } + /** * Check if we have enough participants to start the training * and if all peers that joined the round are ready to exchange weight updates diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index 1e52fe539..1d75f18dc 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -151,9 +151,8 @@ export class FederatedController extends TrainingController< // Reset the training session when all participants left if (this.connections.size === 0) { - debug("All participants left. 
Resetting the training session") - this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') - this.#latestGlobalWeights = this.initialWeights + debug("All participants left. Resetting federated training session") + this.reset() } // Check if we dropped below the minimum number of participant required @@ -166,4 +165,18 @@ export class FederatedController extends TrainingController< this.sendWaitForMoreParticipantsMsg() }) } + + reset(): void { + this.resetConnectionState() + + // Free the tensors held by the outgoing aggregator before it is replaced + this.#aggregator.dispose() + this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') + this.#latestGlobalWeights = this.initialWeights + + // Since we replaced aggregator, we also need to register new aggregation listener + this.#aggregator.on("aggregation", async (weightUpdate) => { + this.#latestGlobalWeights = await serialization.weights.encode(weightUpdate) + }) + } } diff --git a/server/src/controllers/training_controller.ts b/server/src/controllers/training_controller.ts index bbee678ec..bff46247d 100644 --- a/server/src/controllers/training_controller.ts +++ b/server/src/controllers/training_controller.ts @@ -44,6 +44,16 @@ export abstract class TrainingController< abstract handle( ws: WebSocket ): void + + // Reset the controller state + abstract reset(): void + + // Reset the peer connection state + // Used when the training is finished + protected resetConnectionState(): void { + this.waitingForMoreParticipants = true + this.connections = Map() + } /** * If enough participants joined, notifies them that the training can start/resume From 62196b0f899a9bd6f3dc4bbac849005ef6a1f30a Mon Sep 17 00:00:00 2001 From: ahzero7d1 Date: Wed, 17 Jun 2026 21:38:02 +0200 Subject: [PATCH 10/10] Fix decentralized unit test --- server/tests/e2e/decentralized.spec.ts | 327 ++++++++++++------------- 1 file changed, 153 insertions(+), 174 deletions(-) diff --git a/server/tests/e2e/decentralized.spec.ts b/server/tests/e2e/decentralized.spec.ts index 
71f8065b2..6ecf7f337 100644 --- a/server/tests/e2e/decentralized.spec.ts +++ b/server/tests/e2e/decentralized.spec.ts @@ -27,7 +27,6 @@ async function expectWSToBeClose( expect(l).to.be.closeTo(r, 1e-4); } - // function from federated.spec.ts export async function arrayFromAsync(iter: AsyncIterable): Promise { const ret: T[] = []; @@ -315,198 +314,178 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { } }); - it("peers emit expected events", { timeout: 300_000 }, async () => { - const baseTask = await defaultTasks.lusCovid.getTask(); - const task: Task<"image", "decentralized"> = { - ...baseTask, - trainingInformation: { - ...baseTask.trainingInformation, - scheme: "decentralized", - aggregationStrategy: "mean", - roundDuration: 1, - minNbOfParticipants: 2, + // syncs model after participants drop below minNbOfParticipants and newcomers join with mean aggregator + it("peers emit expected events", { timeout: 150_000 }, async () => { + const baseTask = await defaultTasks.lusCovid.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "mean", + roundDuration: 1, + minNbOfParticipants: 2, maxConnectionRetry: 3, - }, - }; - const url = await startServer({ - ...defaultTasks.lusCovid, - getTask: () => Promise.resolve(task), - }); - const dataset = await datasets.loadLusCOVID(); + }, + }; + + const url = await startServer({ + ...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); /** - * Then at each round (each call to `disco.trainByRound`) the event cycle is: - * a) During onRoundBeingCommunication, - * 1. the peer notifies the server that they want to join the next round - * 2. 
finishes by updating the status to "local training" - * (without waiting for a server answer) - * b) local training (the status remains "local training") - * c) During onRoundEndCommunication - * 1. the peer notifies the server that they are ready to share weights - * set status to "connecting to peers" - * 2. wait for the server to answer with the current round's peers list - * this is where the nb of participants is updated - * 3. establish peer-to-peer connections - * 4. set status to "updating model" and exchange weight updates - * - * Given this, it is important to note that calling disco.trainByRound().next() - * for the first time will perform a) and then b) where it stops and yields the round logs. - * Thus, c) isn't called and the weight sharing is not performed during this call to next(). - * Calling next() again will then run c), as well as a) and b) again. - * - * In this test the timeline is: - * - User 1 joins the task by themselves + * Test timeline looks like this: + * - User 1 joins the task * - User 2 joins - * - User 1 leaves - * - User 3 joins + * - User 1 leaves (Since minNbOfParticipants condition is not satisfied, the training stops) + * - User 3 joins (User 3 gets the latest model from User 2 and start local training from that model) * - User 2 & 3 leave */ - /* USER 1 JOINS */ - const discoUser1 = new Disco(task, url, { preprocessOnce: true }); + const discoUser2 = new Disco(task, url, { preprocessOnce: true }); + + // Register listeners for user1 and user2 events const statusUser1 = new Queue(); const nbParticipantsUser1 = new Queue(); - discoUser1.on("status", status => { statusUser1.put(status) }) - discoUser1.on("participants", (participants) => { nbParticipantsUser1.put(participants) }) - const generatorUser1 = discoUser1.trainByRound(dataset) - - // Have User 1 join the task and train locally for one round - const logUser1Round1 = await generatorUser1.next() - expect(logUser1Round1.done).to.be.false - // User 1 did a) and b) so their 
status should be Training - expect(await statusUser1.next()).equal("local training") - expect(await nbParticipantsUser1.next()).equal(1) - - if (logUser1Round1.done) - throw new Error("User 1 finished training at the 1st round") - // participant list not updated yet (updated at step c)) - expect((logUser1Round1.value).participants).equal(1) - - // Calling next() a 2nd time makes User 1 go to c) where the peer should - // stay stuck awaiting until another participant joins - const logUser1Round2Promise = generatorUser1.next() - expect(await statusUser1.next()).equal("connecting to peers") // tries to connect to peers - expect(await statusUser1.next()).equal("not enough participants") // but has to wait for more participants + const statusUser2 = new Queue(); + const nbParticipantsUser2 = new Queue(); + discoUser1.on("status", (status) => statusUser1.put(status)); + discoUser1.on("participants", (participants) => nbParticipantsUser1.put(participants)); + discoUser2.on("status", (status) => statusUser2.put(status)); + discoUser2.on("participants", (participants) => nbParticipantsUser2.put(participants)); + + let user2Closed = false; + + const generatorUser1 = discoUser1.trainByRound(dataset); + const generatorUser2 = discoUser2.trainByRound(dataset); + + /* ROUND 1 */ + /* USER 1 JOINS */ + const round1User1Promise = generatorUser1.next(); + expect(await statusUser1.next()).equal("not enough participants"); + // We expect only one participant + expect(await nbParticipantsUser1.next()).equal(1); /* USER 2 JOINS */ + /* minNbOfParticipants condition satisfied, local training starts */ + const round1User2Promise = generatorUser2.next(); + await Promise.all([round1User1Promise, round1User2Promise]); + + expect(await statusUser2.next()).equal("local training"); + expect(await statusUser1.next()).equal("local training"); + expect(await nbParticipantsUser1.next()).equal(2); + expect(await nbParticipantsUser2.next()).equal(2); + + /* ROUND 2 - first weight exchange */ + await 
Promise.all([ + generatorUser1.next(), + generatorUser2.next(), + ]); + + // Both users did connecting -> updating model -> local training + expect(await statusUser1.next()).equal("connecting to peers"); + expect(await statusUser1.next()).equal("updating model"); + expect(await statusUser1.next()).equal("local training"); + expect(await statusUser2.next()).equal("connecting to peers"); + expect(await statusUser2.next()).equal("updating model"); + expect(await statusUser2.next()).equal("local training"); + expect(await nbParticipantsUser1.next()).equal(2); + expect(await nbParticipantsUser2.next()).equal(2); + + // Weights should have converged after exchanging updates + await expectWSToBeClose( + discoUser1.trainer.model.weights, + discoUser2.trainer.model.weights, + ); - const discoUser2 = new Disco(task, url, { preprocessOnce: true }); - const statusUser2 = new Queue(); - const nbParticipantsUser2 = new Queue(); - discoUser2.on("status", status => { statusUser2.put(status) }) - discoUser2.on("participants", (participants) => { nbParticipantsUser2.put(participants) }) - const generatorUser2 = discoUser2.trainByRound(dataset) - - // Have User 2 join the task and train for one round - const logUser2Round1 = await generatorUser2.next() - expect(logUser2Round1.done).to.be.false - if (logUser2Round1.done) - throw new Error("User 2 finished training at the 1st round") - // round payload should contain the number of participants - expect((logUser2Round1.value).participants).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) - // Receive the EnoughParticipants message with the participants - expect(await nbParticipantsUser1.next()).equal(2) - // User 2 did a) and b) - expect(await statusUser2.next()).equal("local training") - // User 1 is still in c) now waiting for user 2 to be ready to exchange weight updates - expect(await statusUser1.next()).equal("connecting to peers") - - /* ROUND 2 */ - - // The server should answer with the round's peers list. 
- // Peers then exchange updates and then start training locally with the new weights - const logUser2Round2 = await generatorUser2.next() - const logUser1Round2 = await logUser1Round2Promise // the promise can resolve now - expect(logUser1Round2.done).to.be.false - expect(logUser2Round2.done).to.be.false - if (logUser1Round2.done || logUser2Round2.done) - throw new Error("User 1 or 2 finished training at the 2nd round") - // nb of participants should now be updated - expect((logUser1Round2.value).participants).equal(2) - expect((logUser2Round2.value).participants).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) - expect(await nbParticipantsUser1.next()).equal(2) - // User 1 and 2 did c), a) and b) - expect(await statusUser1.next()).equal("updating model") // second to last - expect(await statusUser1.next()).equal("local training") - - expect(await statusUser2.next()).equal("connecting to peers") // back to connecting when user 1 joins - expect(await statusUser2.next()).equal("updating model") - expect(await statusUser2.next()).equal("local training") - - /* USER 1 LEAVES */ - - await discoUser1.close() - // Disconnect updates the number of participants - expect(await nbParticipantsUser1.next()).equal(1) - // User 2 receives the WaitingForMoreParticipants message - expect(await nbParticipantsUser2.next()).equal(1) - // server notifies user 2 to wait - expect(await statusUser2.next()).equal("not enough participants") - // Make user 2 go to c) - const logUser2Round3Promise = generatorUser2.next() - // await new Promise((res, _) => setTimeout(res, statusUpdateTime)) // Wait some time for the status to update - // starts c) and waits for user 3 to join - expect(await statusUser2.next()).equal("connecting to peers") - expect(await statusUser2.next()).equal("not enough participants") + /* USER 2 LEAVES */ + + // ROUND3 starts for User 1 before closing User 2, so User 1 + // enters onRoundEndCommunication and emits "connecting to peers" + const 
user1WaitingPromise = generatorUser1.next(); + expect(await statusUser1.next()).equal("connecting to peers"); + + await discoUser2.close(); + user2Closed = true; + + // Check if User 1 got a signal that there is not enough participants + expect(await nbParticipantsUser1.next()).equal(1); + expect(await statusUser1.next()).equal("not enough participants"); + + // Snapshot User 1's weights, which will be shared to User 3 for model syncing + const latestWeights = new WeightsContainer( + discoUser1.trainer.model.weights.weights.map((w) => w.clone()), + ); /* USER 3 JOINS */ - // Create User 3 + // Create User 3 and register event listeners const discoUser3 = new Disco(task, url, { preprocessOnce: true }); const statusUser3 = new Queue(); const nbParticipantsUser3 = new Queue(); - discoUser3.on("status", status => { statusUser3.put(status) }) - discoUser3.on("participants", (participants) => { nbParticipantsUser3.put(participants) }) - const generatorUser3 = discoUser3.trainByRound(dataset) - - // User 3 joins mid-training and trains one local round - const logUser3Round1 = await generatorUser3.next() - expect(logUser3Round1.done).to.be.false - if (logUser3Round1.done) - throw new Error("User 3 finished training at the 1st round") - expect((logUser3Round1.value).participants).equal(2) - expect(await nbParticipantsUser3.next()).equal(2) - // User 2 receives the EnoughParticipants message - // User 2 is still in c) waiting for user 3 to share their local update - expect(await nbParticipantsUser2.next()).equal(2) - - // User 3 did a) and b) - expect(await statusUser3.next()).equal("local training") - // User 2 is still in c) waiting for user 3 to be ready to exchange waits - expect(await statusUser2.next()).equal("connecting to peers") - - /* ROUND 3 */ - - // User 3 notifies the server that they are ready to exchange waits - // then user 2 and 3 exchange weight updates - const logUser3Round3 = await generatorUser3.next() - const logUser2Round3 = await 
logUser2Round3Promise // the promise can resolve now - if (logUser3Round3.done || logUser2Round3.done) - throw new Error("User 1 or 2 finished training at the 3rd round") - - expect(logUser2Round3.value.participants).equal(2) - expect(logUser3Round3.value.participants).equal(2) - expect(await nbParticipantsUser3.next()).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) - - // both user 2 and 3 did c), a) and are now in b) - expect(await statusUser2.next()).equal("updating model") - expect(await statusUser2.next()).equal("local training") - - expect(await statusUser3.next()).equal("connecting to peers") - expect(await statusUser3.next()).equal("updating model") - expect(await statusUser3.next()).equal("local training") - - /* USER 2 AND 3 LEAVE */ - - await discoUser2.close() - expect(await statusUser3.next()).equal("not enough participants") - expect(await nbParticipantsUser3.next()).equal(1) - - await discoUser3.close() + discoUser3.on("status", (status) => statusUser3.put(status)); + discoUser3.on("participants", (participants) => nbParticipantsUser3.put(participants)); + + const waitForUser3ModelSynced = new Promise((resolve) => { + discoUser3.on("modelSynced", (weights) => { + if (weights !== undefined) resolve(weights); + }); + }); + + const generatorUser3 = discoUser3.trainByRound(dataset); + + /* ROUND 3 */ + /* User 3's first round */ + const user3Round1 = await generatorUser3.next(); + expect(user3Round1.done).to.be.false; + expect(user3Round1.value.participants).equal(2); + + // User 3's model should have been synced to User 1's weights before training + const user3SyncedWeights = await waitForUser3ModelSynced; + await expectWSToBeClose(user3SyncedWeights, latestWeights); + + // User 3 did onRoundBeginCommunication and local training + expect(await statusUser3.next()).equal("local training"); + expect(await nbParticipantsUser3.next()).equal(2); + // User 1 learns User 3 joined and is still in onRoundEndCommunication waiting for User 3 to be 
ready + expect(await nbParticipantsUser1.next()).equal(2); + expect(await statusUser1.next()).equal("connecting to peers"); + + /* ROUND 4 */ + /* first weight exchange between User 1 and User 3 */ + const [user1Round, user3Round2] = await Promise.all([ + user1WaitingPromise, + generatorUser3.next(), + ]); + expect(user1Round.done).to.be.false; + expect(user3Round2.done).to.be.false; + expect(user1Round.value.participants).equal(2); + expect(user3Round2.value.participants).equal(2); + + // Both users did onRoundEndCommunication, onRoundBeginCommunication, and local training + expect(await statusUser1.next()).equal("updating model"); + expect(await statusUser1.next()).equal("not enough participants"); + expect(await statusUser1.next()).equal("local training"); + expect(await nbParticipantsUser1.next()).equal(2); + + expect(await statusUser3.next()).equal("connecting to peers"); + expect(await statusUser3.next()).equal("updating model"); + expect(await statusUser3.next()).equal("local training"); + expect(await nbParticipantsUser3.next()).equal(2); + + + // Weights should have converged between User 1 and User 3 after the exchange + await expectWSToBeClose( + discoUser1.trainer.model.weights, + discoUser3.trainer.model.weights, + ); + + await discoUser3.close(); + await discoUser1.close().catch(() => {}); + if (!user2Closed) await discoUser2.close().catch(() => {}); }); /**