Skip to content
13 changes: 13 additions & 0 deletions discojs/src/aggregator/aggregator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,19 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon
this._nodes = this._nodes.delete(nodeId)
}

/**
 * Release the tensor memory held by every stored contribution,
 * then replace the stored contributions with an empty map.
 */
dispose(): void {
  for (const roundContributions of this.contributions.values()) {
    for (const contribution of roundContributions.values()) {
      contribution.dispose()
    }
  }

  this.contributions = Map()
}

/**
* Overwrites the current set of active nodes with the given one. A node represents
* an active neighbor peer/client within the network, whom we are communicating with
Expand Down
18 changes: 18 additions & 0 deletions discojs/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { EventConnection } from './event_connection.js'
import type { Aggregator } from '../aggregator/index.js'
import { EventEmitter } from '../utils/event_emitter.js'
import { type } from "./messages.js";
import { ModelWeightAccess } from "../training/disco.js";

const debug = createDebug("discojs:client");

Expand All @@ -24,6 +25,7 @@ const debug = createDebug("discojs:client");
export abstract class Client<N extends Network> extends EventEmitter<{
status: RoundStatus;
participants: number;
modelSynced: WeightsContainer | undefined;
}> {
// Own ID provided by the network's server.
protected _ownId?: NodeID
Expand All @@ -38,6 +40,9 @@ export abstract class Client<N extends Network> extends EventEmitter<{
*/
protected promiseForMoreParticipants: Promise<void> | undefined = undefined;

// Interface to access trainer's model weights
protected modelWeightAccess?: ModelWeightAccess;

/**
* When the server notifies the client that they can resume training
* after waiting for more participants, we want to be able to display what
Expand All @@ -56,6 +61,15 @@ export abstract class Client<N extends Network> extends EventEmitter<{
) {
super()
}

/**
 * Used for decentralized learning.
 * Registers the interface through which the client accesses the trainer's
 * model weights. The Disco object provides this access.
 *
 * @param modelWeightAccess accessor stored on the client for later use
 */
setModelWeightAccess(modelWeightAccess: ModelWeightAccess): void {
  this.modelWeightAccess = modelWeightAccess
}

/**
* Communication callback called at the beginning of every training round.
Expand Down Expand Up @@ -193,6 +207,10 @@ export abstract class Client<N extends Network> extends EventEmitter<{
return await serialization.model.decode(encoded)
}

/**
 * No-op by default.
 * DecentralizedClient overrides this method to clean up its round state.
 */
public finishRound(): void {
  // Intentionally empty: only DecentralizedClient has round state to clean up
}

/**
* Number of contributors to a collaborative session
* If decentralized, it should be the number of peers
Expand Down
181 changes: 172 additions & 9 deletions discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import type { DataType, Model, WeightsContainer } from "../../index.js";
import { serialization } from "../../index.js";
import { Client, shortenId } from '../client.js'
import { type NodeID } from '../index.js'
import { type, type ClientConnected } from '../messages.js'
import { type, type ClientConnected, NarrowMessage } from '../messages.js'
import { timeout } from '../utils.js'
import { WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js'
import { PeerPool } from './peer_pool.js'
Expand All @@ -26,6 +26,14 @@ export class DecentralizedClient extends Client<"decentralized"> {
#pool?: PeerPool
#connections?: Map<NodeID, PeerConnection>

// Flag if this model requires model synchronization
#modelSyncNeeded?: boolean

// Check if the training round is in progress
// Used to get the latest model for model synchronization
#roundFinishedPromise?: Promise<void>
#resolveRoundFinished?: () => void // contains resolver

// Used to handle timeouts and promise resolving after calling disconnect
private get isDisconnected() : boolean {
return this._server === undefined
Expand All @@ -36,6 +44,26 @@ export class DecentralizedClient extends Client<"decentralized"> {
// Emits the `participants` event
this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size
}

/**
 * Used by the model provider peer during model syncing: connects to the
 * newly joined node announced by the server and sends it the model weights.
 *
 * This handler is invoked fire-and-forget from the server event listener,
 * so every failure is caught and logged here instead of escaping as an
 * unhandled promise rejection.
 *
 * @param event server message carrying the id of the newly joined node (`newNode`)
 */
private async handleSignalNewPeer(event: NarrowMessage<type.SignalNewPeer>): Promise<void> {
  try {
    if (this.#pool === undefined) {
      throw new Error('received signal about new peer but peer pool is undefined')
    }
    // Capture the promise before awaiting anything: finishRound() resets the field,
    // and sendModel must wait on the round that was in progress when the signal arrived
    const roundFinishedPromise = this.#roundFinishedPromise
    const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, () => {})

    const newcomerConn = syncConnection.get(event.newNode)
    if (newcomerConn === undefined) {
      // if connection with newly joining client fails, print debug message
      // and return
      debug(`Cannot connect to newly joined client [${event.newNode}]`)
      return
    }

    await this.sendModel(newcomerConn, roundFinishedPromise)
  } catch (e) {
    // Nobody awaits this handler, so log the failure rather than rethrowing
    debug(`Failed to send the model to newly joined client [${event.newNode}]: %o`, e)
  }
}

/**
* Public method called by disco.ts when starting training. This method sends
Expand Down Expand Up @@ -69,6 +97,13 @@ export class DecentralizedClient extends Client<"decentralized"> {
this.#pool.signal(event.peer, event.signal)
})

// Listen if the client is selected as a model provider node for a newly joining client.
// Upon receiving the signal, this client establishes a connection with the newcomer
// and sends the latest model weights.
this.server.on(type.SignalNewPeer, (event) => {
void this.handleSignalNewPeer(event)
})

// c.f. setupServerCallbacks doc for explanation
let receivedEnoughParticipants = false
this.setupServerCallbacks(() => receivedEnoughParticipants = true)
Expand All @@ -79,8 +114,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
this.server.send(msg)

const { id, waitForMoreParticipants,
nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo)

nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo)

this.#modelSyncNeeded = joinedMidTraining
this.nbOfParticipants = nbOfParticipants


Expand Down Expand Up @@ -129,14 +165,50 @@ export class DecentralizedClient extends Client<"decentralized"> {
* When connected, one peer creates a promise for every other peer's weight update
* and waits for it to resolve.
*
* If a client joined the training after the first round,
* model syncing happens first to get the latest model.
*/
override async onRoundBeginCommunication(): Promise<void> {
if (this.#modelSyncNeeded) {
// 1. If model sync is needed, send server a request
this.server.send({ type: type.ModelSyncRequest })

// 2. Get the provider information from the server
// NOTE(review): 30_000 is presumably a timeout in milliseconds — confirm against waitMessageWithTimeout
const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider")

if (this.#pool === undefined) {
throw new Error('peer pool is undefined, make sure to call `client.connect()` first')
}

// 3. Connect with model provider client and get the latest model
const syncConnection = await this.#pool.getPeers(
Set([providerInfo.providerNode]),
this.server,
()=>{}
)
Comment on lines +184 to +188

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling getPeers adds the providerNode to the peer pool (this line) which is then never removed and unused in aggregation either.
I don't think that's a big deal as I assume this model sync won't happen that often but worth adding a comment to help debug if it becomes problematic

const providerConn = syncConnection.get(providerInfo.providerNode)

if (providerConn === undefined){
throw new Error("The latest model provider is not connected")
}

// 4. Receive the latest model from the provider and hand it to the trainer
const latestModel = await this.receiveModel(providerConn)
// NOTE(review): if modelWeightAccess was never set, the optional chaining skips
// setModelWeight silently; "modelSynced" is then emitted with undefined and the
// sync flag below is still cleared
this.modelWeightAccess?.setModelWeight(latestModel)

this.emit("modelSynced", this.modelWeightAccess?.getModelWeight())
// Clear the flag so the sync request is not sent again on later rounds
this.#modelSyncNeeded = false
}

// Notify the server we want to join the next round so that the server
// waits for us to be ready before sending the list of peers for the round
this.server.send({ type: type.JoinRound })
// Store the promise for the current round's aggregation result.
// We will await for it to resolve at the end of the round when exchanging weight updates.
this.aggregationResult = this.aggregator.getPromiseForAggregation()

// Do not proceed to local training when minNbOfParticipants condition is not satisfied
await this.waitForParticipantsIfNeeded()

this.saveAndEmit("local training")
// NOTE(review): redundant in an async function, which already returns a promise
return Promise.resolve()
}
Expand All @@ -149,11 +221,55 @@ export class DecentralizedClient extends Client<"decentralized"> {
// Once enough new participants join we can display the previous status again
this.saveAndEmit("connecting to peers")
// First we check if we are waiting for more participants before sending our weight update
await this.waitForParticipantsIfNeeded()
// Create peer-to-peer connections with all peers for the round
await this.establishPeerConnections()

while(true){
// Wait until enough participants are available before continuing the round.
// The minNbOfParticipants requirement is re-checked on every iteration because
// participants may disconnect when connection errors keep happening
await this.waitForParticipantsIfNeeded()

// Create peer-to-peer connections with all peers for the round
await this.establishPeerConnections()

// Wait for connection related messages from the server before exchanging weight updates
// (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange
// (2) If it receives a RetryPeerConnections message, it retries peer connection establishment
// (3) After multiple retries, if the connection is still unsuccessful, the server starts excluding nodes from the round
// and sends a ConnectionFail message to those nodes
// (4) Upon receiving ConnectionFail, the client disconnects from the server
const msg = await Promise.race([
waitMessage(this.server, type.StartWeightSharing),
waitMessage(this.server, type.RetryPeerConnections),
waitMessage(this.server, type.ConnectionFail),
])
Comment on lines +240 to +244

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates two event listeners that are not resolved per round and accumulates throughout the rounds. I don't think it's a huge deal but can you add a comment/TODO to note this?


if (msg.type === type.StartWeightSharing){
// Generate a promise that resolves when round training finishes
if (this.#roundFinishedPromise === undefined){
this.#roundFinishedPromise = new Promise<void>((resolve) => {
this.#resolveRoundFinished = resolve
})
}
break
} else if (msg.type === type.RetryPeerConnections){
debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`)
// clear the communication round peer pool
await this.#pool?.shutdown()
this.#pool = new PeerPool(this.ownId)
// clear the connections
this.#connections = Map()
this.setAggregatorNodes(Set(this.ownId))
continue
} else if (msg.type === type.ConnectionFail){
debug(`[${shortenId(this.ownId)}] disconnect from the server`)
await this.disconnect()
throw new Error("Client disconnected after connection failure")
}
}
// Exchange weight updates with peers and return aggregated weights
return await this.exchangeWeightUpdates(weights)
const aggregatedWeight = await this.exchangeWeightUpdates(weights)
Comment on lines +254 to +270

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this scenario ever covered in the unit tests?


return aggregatedWeight
}

/**
Expand All @@ -178,8 +294,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
try {
debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`);
const receivedMessage = await waitMessage(this.server, type.PeersForRound)

const peers = Set(receivedMessage.peers)
debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray());

if (this.ownId !== undefined && peers.has(this.ownId)) {
throw new Error('received peer list contains our own id')
Expand All @@ -198,7 +315,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
(conn) => this.receivePayloads(conn)
)

debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
// Signal server that all connections with other peers in the round are established
this.server.send({ type: type.ConnectionsReady });
debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS());
this.#connections = connections
} catch (e) {
debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
Expand Down Expand Up @@ -303,4 +422,48 @@ export class DecentralizedClient extends Client<"decentralized"> {
}
return await this.aggregationResult
}

/**
 * Receive the latest model from the model provider.
 * Waits for a `SharedModel` message on the given connection and decodes the
 * weights it carries.
 *
 * @param providerConn peer connection to the model provider
 * @returns the decoded model weights
 */
private async receiveModel(providerConn: PeerConnection): Promise<WeightsContainer>{
// NOTE(review): 30_000 is presumably the timeout in milliseconds and the string the
// timeout error message — confirm against waitMessageWithTimeout
const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should make this timeout value a task parameter, depending on the model size it can make sense to increase the timeout


// Decode the serialized weights carried in the message's `model` field
const decoded = serialization.weights.decode(message.model)
return decoded
}

/**
 * Share the most recent model weights with a client that has just joined.
 * When a training round is still running, first wait for it to finish so
 * that the weights read afterwards are the ones available after that round.
 *
 * @param newcomerConn peer connection to the newly joined client
 * @param roundFinishedPromise promise of the round in progress, or undefined when there is none to wait for
 */
private async sendModel(newcomerConn: PeerConnection, roundFinishedPromise: Promise<void> | undefined): Promise<void> {
  // Only wait when a round-finished promise was handed over
  if (roundFinishedPromise !== undefined) {
    await roundFinishedPromise
  }

  const latestWeights = this.modelWeightAccess?.getModelWeight()

  if (latestWeights !== undefined) {
    const serialized = await serialization.weights.encode(latestWeights)
    const payload: messages.SharedModel = { type: type.SharedModel, model: serialized }
    newcomerConn.send(payload)
    return
  }

  // Nothing to send: the weights could not be read through modelWeightAccess
  debug("Failed to get the latest model from model provider client")
}

/**
 * Resolve the round-finished promise and reset the related state,
 * so that a model synchronization waiting on the round can proceed.
 */
override finishRound(): void {
  // Signal the end of the round to anyone awaiting the promise
  if (this.#resolveRoundFinished !== undefined) {
    this.#resolveRoundFinished()
  }

  // Forget the resolver and the promise
  this.#resolveRoundFinished = undefined
  this.#roundFinishedPromise = undefined
}
}
Loading