-
Notifications
You must be signed in to change notification settings - Fork 32
Fix decentralized clients connection issue #1110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
2900c65
dd22a7d
7d0e634
947d2c6
18e777e
b443f53
9c7c488
09ced04
14b2a4b
3bfce91
62196b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ import type { DataType, Model, WeightsContainer } from "../../index.js"; | |
| import { serialization } from "../../index.js"; | ||
| import { Client, shortenId } from '../client.js' | ||
| import { type NodeID } from '../index.js' | ||
| import { type, type ClientConnected } from '../messages.js' | ||
| import { type, type ClientConnected, NarrowMessage } from '../messages.js' | ||
| import { timeout } from '../utils.js' | ||
| import { WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js' | ||
| import { PeerPool } from './peer_pool.js' | ||
|
|
@@ -26,6 +26,14 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| #pool?: PeerPool | ||
| #connections?: Map<NodeID, PeerConnection> | ||
|
|
||
| // Flag if this model requires model synchronization | ||
| #modelSyncNeeded?: boolean | ||
|
|
||
| // Check if the training round is in progress | ||
| // Used to get the latest model for model synchronization | ||
| #roundFinishedPromise?: Promise<void> | ||
| #resolveRoundFinished?: () => void // contains resolver | ||
|
|
||
| // Used to handle timeouts and promise resolving after calling disconnect | ||
| private get isDisconnected() : boolean { | ||
| return this._server === undefined | ||
|
|
@@ -36,6 +44,26 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| // Emits the `participants` event | ||
| this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size | ||
| } | ||
|
|
||
| // Used by model provider peer during model syncing | ||
| private async handleSignalNewPeer(event: NarrowMessage<type.SignalNewPeer>): Promise<void> { | ||
| if (this.#pool === undefined){ | ||
| throw new Error('received signal about new peer but peer pool is undefined') | ||
| } | ||
| const roundFinishedPromise = this.#roundFinishedPromise | ||
| const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) | ||
|
|
||
| const newcomerConn = syncConnection.get(event.newNode) | ||
|
|
||
| if (newcomerConn === undefined){ | ||
| // if connection with newly joining client fails, print debug message | ||
| // and return | ||
| debug(`Cannot connect to newly joined client [${event.newNode}]`) | ||
| return | ||
| } | ||
|
|
||
| await this.sendModel(newcomerConn, roundFinishedPromise) | ||
| } | ||
|
|
||
| /** | ||
| * Public method called by disco.ts when starting training. This method sends | ||
|
|
@@ -69,6 +97,13 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| this.#pool.signal(event.peer, event.signal) | ||
| }) | ||
|
|
||
| // Listen if the client is selected as a model provider node for a newly joining client. | ||
| // Upon receiving the signal, this client establishes a connection with the newcomer | ||
| // and sends the latest model weights. | ||
| this.server.on(type.SignalNewPeer, (event) => { | ||
| void this.handleSignalNewPeer(event) | ||
| }) | ||
|
|
||
| // c.f. setupServerCallbacks doc for explanation | ||
| let receivedEnoughParticipants = false | ||
| this.setupServerCallbacks(() => receivedEnoughParticipants = true) | ||
|
|
@@ -79,8 +114,9 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| this.server.send(msg) | ||
|
|
||
| const { id, waitForMoreParticipants, | ||
| nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) | ||
|
|
||
| nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) | ||
|
|
||
| this.#modelSyncNeeded = joinedMidTraining | ||
| this.nbOfParticipants = nbOfParticipants | ||
|
|
||
|
|
||
|
|
@@ -129,14 +165,50 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| * When connected, one peer creates a promise for every other peer's weight update | ||
| * and waits for it to resolve. | ||
| * | ||
| * If a client joined the training after the first round, | ||
| * model syncing happens first to get the latest model. | ||
| */ | ||
| override async onRoundBeginCommunication(): Promise<void> { | ||
| if (this.#modelSyncNeeded) { | ||
| // 1. If model sync is needed, send server a request | ||
| this.server.send({ type: type.ModelSyncRequest }) | ||
|
|
||
| // 2. Get the provider information from the server | ||
| const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider") | ||
|
|
||
| if (this.#pool === undefined) { | ||
| throw new Error('peer pool is undefined, make sure to call `client.connect()` first') | ||
| } | ||
|
|
||
| // 3. Connect with model provider client and get the latest model | ||
| const syncConnection = await this.#pool.getPeers( | ||
| Set([providerInfo.providerNode]), | ||
| this.server, | ||
| ()=>{} | ||
| ) | ||
| const providerConn = syncConnection.get(providerInfo.providerNode) | ||
|
|
||
| if (providerConn === undefined){ | ||
| throw new Error("The latest model provider is not connected") | ||
| } | ||
|
|
||
| const latestModel = await this.receiveModel(providerConn) | ||
| this.modelWeightAccess?.setModelWeight(latestModel) | ||
|
|
||
| this.emit("modelSynced", this.modelWeightAccess?.getModelWeight()) | ||
| this.#modelSyncNeeded = false | ||
| } | ||
|
|
||
| // Notify the server we want to join the next round so that the server | ||
| // waits for us to be ready before sending the list of peers for the round | ||
| this.server.send({ type: type.JoinRound }) | ||
| // Store the promise for the current round's aggregation result. | ||
| // We will await for it to resolve at the end of the round when exchanging weight updates. | ||
| this.aggregationResult = this.aggregator.getPromiseForAggregation() | ||
|
|
||
| // Do not proceed to local training when minNbOfParticipants condition is not satisfied | ||
| await this.waitForParticipantsIfNeeded() | ||
|
|
||
| this.saveAndEmit("local training") | ||
| return Promise.resolve() | ||
| } | ||
|
|
@@ -149,11 +221,55 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| // Once enough new participants join we can display the previous status again | ||
| this.saveAndEmit("connecting to peers") | ||
| // First we check if we are waiting for more participants before sending our weight update | ||
| await this.waitForParticipantsIfNeeded() | ||
| // Create peer-to-peer connections with all peers for the round | ||
| await this.establishPeerConnections() | ||
|
|
||
| while(true){ | ||
| // Wait until enough participants are available before continuing the round | ||
| // Checks minNbOfParticipants requirement for | ||
| // when participants disconnect when connection error happens continuously | ||
| await this.waitForParticipantsIfNeeded() | ||
|
|
||
| // Create peer-to-peer connections with all peers for the round | ||
| await this.establishPeerConnections() | ||
|
|
||
| // Wait for connection related messages from the server before exchanging weight updates | ||
| // (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange | ||
| // (2) If it receives a RetryPeerConnections message, it retries peer connection establishment | ||
| // (3) After multiple retries, if the connection is still unsuccessful, the server starts excluding nodes from the round | ||
| // and sends a ConnectionFail message to those nodes | ||
| // (4) Upon receiving ConnectionFail, the client disconnects from the server | ||
| const msg = await Promise.race([ | ||
| waitMessage(this.server, type.StartWeightSharing), | ||
| waitMessage(this.server, type.RetryPeerConnections), | ||
| waitMessage(this.server, type.ConnectionFail), | ||
| ]) | ||
|
Comment on lines
+240
to
+244
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This creates two event listeners that are not resolved per round and accumulate throughout the rounds. I don't think it's a huge deal, but can you add a comment/TODO to note this? |
||
|
|
||
| if (msg.type === type.StartWeightSharing){ | ||
| // Generate a promise that resolves when round training finishes | ||
| if (this.#roundFinishedPromise === undefined){ | ||
| this.#roundFinishedPromise = new Promise<void>((resolve) => { | ||
| this.#resolveRoundFinished = resolve | ||
| }) | ||
| } | ||
| break | ||
| } else if (msg.type === type.RetryPeerConnections){ | ||
| debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`) | ||
| // clear the communication round peer pool | ||
| await this.#pool?.shutdown() | ||
| this.#pool = new PeerPool(this.ownId) | ||
| // clear the connections | ||
| this.#connections = Map() | ||
| this.setAggregatorNodes(Set(this.ownId)) | ||
| continue | ||
| } else if (msg.type === type.ConnectionFail){ | ||
| debug(`[${shortenId(this.ownId)}] disconnect from the server`) | ||
| await this.disconnect() | ||
| throw new Error("Client disconnected after connection failure") | ||
| } | ||
| } | ||
| // Exchange weight updates with peers and return aggregated weights | ||
| return await this.exchangeWeightUpdates(weights) | ||
| const aggregatedWeight = await this.exchangeWeightUpdates(weights) | ||
|
Comment on lines
+254
to
+270
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this scenario ever covered in the unit tests? |
||
|
|
||
| return aggregatedWeight | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -178,8 +294,9 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| try { | ||
| debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); | ||
| const receivedMessage = await waitMessage(this.server, type.PeersForRound) | ||
|
|
||
| const peers = Set(receivedMessage.peers) | ||
| debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); | ||
|
|
||
| if (this.ownId !== undefined && peers.has(this.ownId)) { | ||
| throw new Error('received peer list contains our own id') | ||
|
|
@@ -198,7 +315,9 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| (conn) => this.receivePayloads(conn) | ||
| ) | ||
|
|
||
| debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); | ||
| // Signal server that all connections with other peers in the round are established | ||
| this.server.send({ type: type.ConnectionsReady }); | ||
| debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS()); | ||
| this.#connections = connections | ||
| } catch (e) { | ||
| debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e); | ||
|
|
@@ -303,4 +422,48 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| } | ||
| return await this.aggregationResult | ||
| } | ||
|
|
||
| /** | ||
| * Receive model from the model provider. | ||
| */ | ||
| private async receiveModel(providerConn: PeerConnection): Promise<WeightsContainer>{ | ||
| const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we should make this timeout value a task parameter. Depending on the model size, it can make sense to increase the timeout. |
||
|
|
||
| const decoded = serialization.weights.decode(message.model) | ||
| return decoded | ||
| } | ||
|
|
||
| /** | ||
| * Send the latest available model to a newly joining client. | ||
| * If the current training round is in progress, wait until the round finishes | ||
| * and receive the latest aggregated model. | ||
| */ | ||
| private async sendModel(newcomerConn: PeerConnection, roundFinishedPromise: Promise<void> | undefined): Promise<void> { | ||
| // wait until the round finishes to get the latest model | ||
| if (roundFinishedPromise !== undefined){ | ||
| await roundFinishedPromise | ||
| } | ||
|
|
||
| const model = this.modelWeightAccess?.getModelWeight() | ||
|
|
||
| if (model === undefined){ | ||
| debug("Failed to get the latest model from model provider client") | ||
| return | ||
| } | ||
| const encoded = await serialization.weights.encode(model) | ||
|
|
||
| const message: messages.SharedModel = { | ||
| type: type.SharedModel, | ||
| model: encoded | ||
| } | ||
| newcomerConn.send(message) | ||
| } | ||
|
|
||
| // Resolve the round finished promise and reset related state | ||
| override finishRound(): void{ | ||
| // Mark round as finished so that model synchronization can proceed | ||
| this.#resolveRoundFinished?.() | ||
| this.#roundFinishedPromise = undefined | ||
| this.#resolveRoundFinished = undefined | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling `getPeers` adds the `providerNode` to the peer pool (this line), which is then never removed and is unused in aggregation either. I don't think that's a big deal, as I assume this model sync won't happen that often, but it's worth adding a comment to help debug if it becomes problematic.