diff --git a/node/node.go b/node/node.go index e8b93e2372..07faab89a5 100644 --- a/node/node.go +++ b/node/node.go @@ -301,7 +301,7 @@ func New(cfg *Config, version string, logLevel *log.Level) (*Node, error) { logger, ), cfg.MaxConcurrentCompilations, - int32(cfg.MaxCompilationQueue), + uint64(cfg.MaxCompilationQueue), ) if cfg.Sequencer { @@ -320,7 +320,7 @@ func New(cfg *Config, version string, logLevel *log.Level) (*Node, error) { FeeTokenAddresses: feeTokens, } nodeVM = vm.New(&chainInfo, false, logger) - throttledVM = NewThrottledVM(nodeVM, cfg.MaxVMs, int32(cfg.MaxVMQueue)) + throttledVM = NewThrottledVM(nodeVM, cfg.MaxVMs, uint64(cfg.MaxVMQueue)) mempool := mempool.New(database, chain, mempoolLimit, logger) executor := builder.NewExecutor(chain, nodeVM, logger, cfg.SeqDisableFees, false) builder := builder.New(chain, executor) @@ -380,7 +380,7 @@ func New(cfg *Config, version string, logLevel *log.Level) (*Node, error) { FeeTokenAddresses: feeTokens, } nodeVM = vm.New(&chainInfo, false, logger) - throttledVM = NewThrottledVM(nodeVM, cfg.MaxVMs, int32(cfg.MaxVMQueue)) + throttledVM = NewThrottledVM(nodeVM, cfg.MaxVMs, uint64(cfg.MaxVMQueue)) feederGatewayDataSource := sync.NewFeederGatewayDataSource(chain, adaptfeeder.New(client)) synchronizer = sync.New( diff --git a/node/throttled_compiler.go b/node/throttled_compiler.go index 8e6ba902f2..597cc089a1 100644 --- a/node/throttled_compiler.go +++ b/node/throttled_compiler.go @@ -5,20 +5,22 @@ import ( "github.com/NethermindEth/juno/starknet" "github.com/NethermindEth/juno/starknet/compiler" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" ) var _ compiler.Compiler = (*ThrottledCompiler)(nil) type ThrottledCompiler struct { - *utils.Throttler[compiler.Compiler] + *throttler.Throttler[compiler.Compiler] } func NewThrottledCompiler( - res compiler.Compiler, concurrencyBudget uint, maxQueueLen int32, + res compiler.Compiler, concurrencyBudget uint, maxQueueLen uint64, ) *ThrottledCompiler { return &ThrottledCompiler{ - Throttler: utils.NewThrottler(concurrencyBudget, &res).WithMaxQueueLen(maxQueueLen), + Throttler: throttler.NewThrottler( + concurrencyBudget, &res, throttler.WithMaxQueueLen(maxQueueLen), + ), } } @@ -26,7 +28,7 @@ func (tc *ThrottledCompiler) Compile( ctx context.Context, sierra *starknet.SierraClass, ) (*starknet.CasmClass, error) { var result *starknet.CasmClass - err := tc.Do(func(c *compiler.Compiler) error { + err := tc.Do(ctx, func(c *compiler.Compiler) error { var cErr error result, cErr = (*c).Compile(ctx, sierra) return cErr diff --git a/node/throttled_compiler_test.go b/node/throttled_compiler_test.go index 63c80fd3fd..a6975e412c 100644 --- a/node/throttled_compiler_test.go +++ b/node/throttled_compiler_test.go @@ -8,7 +8,7 @@ import ( "github.com/NethermindEth/juno/node" "github.com/NethermindEth/juno/starknet" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -52,7 +52,7 @@ func TestThrottledCompiler(t *testing.T) { // The queue is full, so the next request is rejected. _, err := throttled.Compile(t.Context(), &starknet.SierraClass{}) - require.ErrorIs(t, err, utils.ErrResourceBusy) + require.ErrorIs(t, err, throttler.ErrResourceBusy) // Release the four running/queued jobs and let them finish. for range 4 { diff --git a/node/throttled_vm.go b/node/throttled_vm.go index a0f856da07..71aba3dad6 100644 --- a/node/throttled_vm.go +++ b/node/throttled_vm.go @@ -1,21 +1,23 @@ package node import ( + "context" + "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" ) var _ vm.VM = (*ThrottledVM)(nil) type ThrottledVM struct { - *utils.Throttler[vm.VM] + *throttler.Throttler[vm.VM] } -func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *ThrottledVM { +func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen uint64) *ThrottledVM { return &ThrottledVM{ - Throttler: utils.NewThrottler(concurrenyBudget, &res).WithMaxQueueLen(maxQueueLen), + Throttler: throttler.NewThrottler(concurrenyBudget, &res, throttler.WithMaxQueueLen(maxQueueLen)), } } @@ -28,7 +30,8 @@ func (tvm *ThrottledVM) Call( errStack, returnStateDiff bool, ) (vm.CallResult, error) { ret := vm.CallResult{} - return ret, tvm.Do(func(vm *vm.VM) error { + // vm.VM carries no ctx; queued VM calls aren't cancellable yet. + return ret, tvm.Do(context.Background(), func(vm *vm.VM) error { var err error ret, err = (*vm).Call( callInfo, @@ -47,7 +50,8 @@ func (tvm *ThrottledVM) runExec( fn func(inner vm.VM) (vm.ExecutionResults, error), ) (vm.ExecutionResults, error) { var result vm.ExecutionResults - return result, tvm.Do(func(inner *vm.VM) error { + // vm.VM carries no ctx; queued VM calls aren't cancellable yet. + return result, tvm.Do(context.Background(), func(inner *vm.VM) error { var err error result, err = fn(*inner) return err diff --git a/rpc/v10/simulation.go b/rpc/v10/simulation.go index 655996f7b3..d4cb2b1a57 100644 --- a/rpc/v10/simulation.go +++ b/rpc/v10/simulation.go @@ -14,7 +14,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" ) @@ -319,7 +319,7 @@ func (h *Handler) prepareTransactions( network, ) if aErr != nil { - if errors.Is(aErr, utils.ErrResourceBusy) { + if errors.Is(aErr, throttler.ErrResourceBusy) { return nil, nil, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledCompilerErr) } return nil, nil, jsonrpc.Err(jsonrpc.InvalidParams, aErr.Error()) @@ -335,7 +335,7 @@ func (h *Handler) prepareTransactions( } func handleExecutionError(err error) *jsonrpc.Error { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } var txnExecutionError vm.TransactionExecutionError diff --git a/rpc/v10/simulation_pkg_test.go b/rpc/v10/simulation_pkg_test.go index fa842ff6b3..bd43a7d700 100644 --- a/rpc/v10/simulation_pkg_test.go +++ b/rpc/v10/simulation_pkg_test.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" "github.com/stretchr/testify/require" ) @@ -130,7 +130,7 @@ func TestHandleExecutionError(t *testing.T) { }{ { name: "Resource Busy Error", - err: utils.ErrResourceBusy, + err: throttler.ErrResourceBusy, jsonRPCError: rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr), }, { diff --git a/rpc/v10/trace.go b/rpc/v10/trace.go index 28da052320..af89cb4f7b 100644 --- a/rpc/v10/trace.go +++ b/rpc/v10/trace.go @@ -18,6 +18,7 @@ import ( "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/sync" "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" ) @@ -100,7 +101,7 @@ func (h *Handler) Call( false, ) if err != nil { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return nil, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } return nil, MakeContractError(json.RawMessage(err.Error())) @@ -192,7 +193,7 @@ func traceTransactionsWithState( httpHeader.Set(ExecutionStepsHeader, strconv.FormatUint(executionResult.NumSteps, 10)) if vmErr != nil { - if errors.Is(vmErr, utils.ErrResourceBusy) { + if errors.Is(vmErr, throttler.ErrResourceBusy) { return nil, nil, httpHeader, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } return nil, nil, httpHeader, rpccore.ErrUnexpectedError.CloneWithData(vmErr.Error()) diff --git a/rpc/v10/transaction.go b/rpc/v10/transaction.go index 7161f63ea1..6a36d44cdf 100644 --- a/rpc/v10/transaction.go +++ b/rpc/v10/transaction.go @@ -17,6 +17,7 @@ import ( "github.com/NethermindEth/juno/starknet" "github.com/NethermindEth/juno/starknet/compiler" "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "go.uber.org/zap" ) @@ -200,7 +201,7 @@ func (h *Handler) addToMempool( ctx, h.compiler, tx, h.bcReader.Network(), ) if err != nil { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return AddTxResponse{}, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledCompilerErr) } return AddTxResponse{}, rpccore.ErrInternal.CloneWithData(err.Error()) diff --git a/rpc/v8/simulation.go b/rpc/v8/simulation.go index 6f0605b740..5b30937975 100644 --- a/rpc/v8/simulation.go +++ b/rpc/v8/simulation.go @@ -14,7 +14,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -197,7 +197,7 @@ func (h *Handler) prepareTransactions( network, ) if aErr != nil { - if errors.Is(aErr, utils.ErrResourceBusy) { + if errors.Is(aErr, throttler.ErrResourceBusy) { return nil, nil, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledCompilerErr) } return nil, nil, jsonrpc.Err(jsonrpc.InvalidParams, aErr.Error()) @@ -213,7 +213,7 @@ func (h *Handler) prepareTransactions( } func handleExecutionError(err error) *jsonrpc.Error { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } var txnExecutionError vm.TransactionExecutionError diff --git a/rpc/v8/simulation_pkg_test.go b/rpc/v8/simulation_pkg_test.go index 9b71ab4370..c30ad9e58e 100644 --- a/rpc/v8/simulation_pkg_test.go +++ b/rpc/v8/simulation_pkg_test.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" "github.com/stretchr/testify/require" ) @@ -123,7 +123,7 @@ func TestHandleExecutionError(t *testing.T) { }{ { name: "Resource Busy Error", - err: utils.ErrResourceBusy, + err: throttler.ErrResourceBusy, jsonRPCError: rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr), }, { diff --git a/rpc/v8/trace.go b/rpc/v8/trace.go index ccc7f55725..d3b26d1bc5 100644 --- a/rpc/v8/trace.go +++ b/rpc/v8/trace.go @@ -16,6 +16,7 @@ import ( "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" ) @@ -239,7 +240,7 @@ func (h *Handler) traceBlockTransactionWithVM(block *core.Block) ( httpHeader.Set(ExecutionStepsHeader, strconv.FormatUint(executionResult.NumSteps, 10)) if err != nil { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return nil, httpHeader, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } // Since we are tracing an existing block, we know that there should be no errors during execution. If we encounter any, @@ -387,7 +388,7 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json false, ) if err != nil { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return nil, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } return nil, MakeContractError(json.RawMessage(err.Error())) diff --git a/rpc/v8/transaction.go b/rpc/v8/transaction.go index 90668dfc77..a99df814e7 100644 --- a/rpc/v8/transaction.go +++ b/rpc/v8/transaction.go @@ -21,6 +21,7 @@ import ( "github.com/NethermindEth/juno/starknet" "github.com/NethermindEth/juno/starknet/compiler" "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "go.uber.org/zap" ) @@ -664,7 +665,7 @@ func (h *Handler) addToMempool(ctx context.Context, tx *BroadcastedTransaction) ctx, h.compiler, tx, h.bcReader.Network(), ) if err != nil { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return AddTxResponse{}, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledCompilerErr) } return AddTxResponse{}, rpccore.ErrInternal.CloneWithData(err.Error()) diff --git a/rpc/v9/simulation.go b/rpc/v9/simulation.go index 61939af76c..bb95eee551 100644 --- a/rpc/v9/simulation.go +++ b/rpc/v9/simulation.go @@ -14,7 +14,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -207,7 +207,7 @@ func (h *Handler) prepareTransactions( network, ) if aErr != nil { - if errors.Is(aErr, utils.ErrResourceBusy) { + if errors.Is(aErr, throttler.ErrResourceBusy) { return nil, nil, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledCompilerErr) } return nil, nil, jsonrpc.Err(jsonrpc.InvalidParams, aErr.Error()) @@ -223,7 +223,7 @@ func (h *Handler) prepareTransactions( } func handleExecutionError(err error) *jsonrpc.Error { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } var txnExecutionError vm.TransactionExecutionError diff --git a/rpc/v9/simulation_pkg_test.go b/rpc/v9/simulation_pkg_test.go index 0a37ca9e78..2bfaf330a2 100644 --- a/rpc/v9/simulation_pkg_test.go +++ b/rpc/v9/simulation_pkg_test.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" "github.com/stretchr/testify/require" ) @@ -123,7 +123,7 @@ func TestHandleExecutionError(t *testing.T) { }{ { name: "Resource Busy Error", - err: utils.ErrResourceBusy, + err: throttler.ErrResourceBusy, jsonRPCError: rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr), }, { diff --git a/rpc/v9/trace.go b/rpc/v9/trace.go index 8e87c63cb8..c83d622f58 100644 --- a/rpc/v9/trace.go +++ b/rpc/v9/trace.go @@ -18,6 +18,7 @@ import ( "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/sync" "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "github.com/NethermindEth/juno/vm" ) @@ -158,7 +159,7 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json false, ) if err != nil { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return nil, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } return nil, MakeContractError(json.RawMessage(err.Error())) @@ -227,7 +228,7 @@ func traceTransactionsWithState( httpHeader.Set(ExecutionStepsHeader, strconv.FormatUint(executionResult.NumSteps, 10)) if vmErr != nil { - if errors.Is(vmErr, utils.ErrResourceBusy) { + if errors.Is(vmErr, throttler.ErrResourceBusy) { return nil, httpHeader, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledVMErr) } return nil, httpHeader, rpccore.ErrUnexpectedError.CloneWithData(vmErr.Error()) diff --git a/rpc/v9/transaction.go b/rpc/v9/transaction.go index ed4c7fdca5..2d321128f2 100644 --- a/rpc/v9/transaction.go +++ b/rpc/v9/transaction.go @@ -21,6 +21,7 @@ import ( "github.com/NethermindEth/juno/starknet" "github.com/NethermindEth/juno/starknet/compiler" "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/throttler" "go.uber.org/zap" ) @@ -723,7 +724,7 @@ func (h *Handler) addToMempool(ctx context.Context, tx *BroadcastedTransaction) ctx, h.compiler, tx, h.bcReader.Network(), ) if err != nil { - if errors.Is(err, utils.ErrResourceBusy) { + if errors.Is(err, throttler.ErrResourceBusy) { return AddTxResponse{}, rpccore.ErrInternal.CloneWithData(rpccore.ThrottledCompilerErr) } return AddTxResponse{}, rpccore.ErrInternal.CloneWithData(err.Error()) diff --git a/utils/throttler.go b/utils/throttler.go deleted file mode 100644 index 0d54c24e2c..0000000000 --- a/utils/throttler.go +++ /dev/null @@ -1,56 +0,0 @@ -package utils - -import ( - "errors" - "math" - "sync/atomic" -) - -var ErrResourceBusy = errors.New("resource busy, try again") - -type Throttler[T any] struct { - resource *T - sem chan struct{} - queue atomic.Int32 - - maxQueueLen int32 -} - -func NewThrottler[T any](concurrencyBudget uint, resource *T) *Throttler[T] { - return &Throttler[T]{ - resource: resource, - sem: make(chan struct{}, concurrencyBudget), - maxQueueLen: math.MaxInt32, - } -} - -// WithMaxQueueLen sets the maximum length the queue can grow to -func (t *Throttler[T]) WithMaxQueueLen(maxQueueLen int32) *Throttler[T] { - t.maxQueueLen = maxQueueLen - return t -} - -// Do lets caller acquire the resource within the context of a callback -func (t *Throttler[T]) Do(doer func(resource *T) error) error { - queueLen := t.queue.Add(1) - if queueLen > t.maxQueueLen { - t.queue.Add(-1) - return ErrResourceBusy - } - t.sem <- struct{}{} - defer func() { - <-t.sem - }() - t.queue.Add(-1) - return doer(t.resource) -} - -// QueueLen returns the number of Do calls that is blocked on the resource -func (t *Throttler[T]) QueueLen() int { - return int(t.queue.Load()) -} - -// JobsRunning returns the number of Do calls that are running at the moment -func (t *Throttler[T]) JobsRunning() int { - return len(t.sem) -} diff --git a/utils/throttler/throttler.go b/utils/throttler/throttler.go new file mode 100644 index 0000000000..c1021556ca --- /dev/null +++ b/utils/throttler/throttler.go @@ -0,0 +1,94 @@ +package throttler + +import ( + "context" + "errors" + "math" + "sync/atomic" +) + +var ErrResourceBusy = errors.New("resource busy, try again") + +// Throttler limits how many times an action is done concurrently +// and how many requests for these actions can be queued at max. +type Throttler[T any] struct { + resource *T + sem chan struct{} + + // currentRequests counts current active and queued requests + currentRequests atomic.Uint64 + // maxRequests is the total of possible requests (active + queued) + maxRequests uint64 +} + +type options struct { + maxQueueLen uint64 +} + +type Option func(*options) + +// WithMaxQueueLen sets the maximum length the queue can grow to. +func WithMaxQueueLen(maxQueueLen uint64) Option { + return func(o *options) { + o.maxQueueLen = maxQueueLen + } +} + +// NewThrottler returns a new throttler that will allow up to `maxConcurrentReqs` concurrent +// requests for resource `T`. See [throttler.Option] for other options. +func NewThrottler[T any](maxConcurrentReqs uint, resource *T, opts ...Option) *Throttler[T] { + o := options{ + maxQueueLen: 1024, + } + for _, opt := range opts { + opt(&o) + } + + // guard against overflow + maxRequests := o.maxQueueLen + uint64(maxConcurrentReqs) + if maxRequests < o.maxQueueLen { + maxRequests = math.MaxUint64 + } + + return &Throttler[T]{ + resource: resource, + sem: make(chan struct{}, maxConcurrentReqs), + + currentRequests: atomic.Uint64{}, + maxRequests: maxRequests, + } +} + +// Do lets caller acquire the resource within the context of a callback +func (t *Throttler[T]) Do(ctx context.Context, doer func(resource *T) error) error { + if err := ctx.Err(); err != nil { + return err // already cancelled, don't even enter the queue + } + + activeReqs := t.currentRequests.Add(1) + defer t.currentRequests.Add(^uint64(0)) // decrement by 1 + if activeReqs > t.maxRequests { + return ErrResourceBusy + } + + select { + case t.sem <- struct{}{}: + case <-ctx.Done(): + return ctx.Err() + } + + defer func() { + <-t.sem + }() + return doer(t.resource) +} + +// QueueLen returns the number of Do calls that is blocked on the resource +func (t *Throttler[T]) QueueLen() int { + return int(t.currentRequests.Load()) - len(t.sem) +} + +// JobsRunning returns the number of Do calls that are running at the moment +func (t *Throttler[T]) JobsRunning() int { + return len(t.sem) +} diff --git a/utils/throttler/throttler_test.go b/utils/throttler/throttler_test.go new file mode 100644 index 0000000000..851e5705bb --- /dev/null +++ b/utils/throttler/throttler_test.go @@ -0,0 +1,173 @@ +package throttler_test + +import ( + "context" + "errors" + "math" + "sync" + "sync/atomic" + "testing" + "testing/synctest" + + "github.com/NethermindEth/juno/utils/throttler" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestThrottler(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + throttledRes := throttler.NewThrottler(2, new(int), throttler.WithMaxQueueLen(2)) + ctx := t.Context() + waitOn := make(chan struct{}) + + var runCount int64 + doer := func(ptr *int) error { + if ptr == nil { + return errors.New("nilptr") + } + <-waitOn + atomic.AddInt64(&runCount, 1) + return nil + } + + var wg sync.WaitGroup + // do spawns a Do call, then waits for the bubble to settle so the + // throttler's queue length is observed deterministically. + do := func() { + wg.Go(func() { + // assert, not require: runs off the test goroutine + assert.NoError(t, throttledRes.Do(ctx, doer)) + }) + synctest.Wait() // block until the spawned goroutine is durably blocked + } + + do() + assert.Equal(t, 0, throttledRes.QueueLen()) + do() + assert.Equal(t, 0, throttledRes.QueueLen()) + + do() // should be queued + assert.Equal(t, 1, throttledRes.QueueLen()) + do() // should be queued + assert.Equal(t, 2, throttledRes.QueueLen()) + + require.ErrorIs(t, throttledRes.Do(ctx, doer), throttler.ErrResourceBusy) + + waitOn <- struct{}{} // release one of the slots + synctest.Wait() + assert.Equal(t, 1, throttledRes.QueueLen()) + waitOn <- struct{}{} // release another slot, queue should be empty + synctest.Wait() + assert.Equal(t, 0, throttledRes.QueueLen()) + + // release the jobs waiting + waitOn <- struct{}{} + waitOn <- struct{}{} + wg.Wait() + assert.Equal(t, int64(4), runCount) + }) +} + +func TestThrottlerContextCancellation(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Budget of 1 and a generous queue, so the second call is queued (waiting + // for the slot) rather than rejected with ErrResourceBusy. + throttledRes := throttler.NewThrottler(1, new(int), throttler.WithMaxQueueLen(10)) + waitOn := make(chan struct{}) + + //nolint: unparam // defined this way to satisfy a function signature + blockingDoer := func(*int) error { + <-waitOn + return nil + } + + // Occupy the only slot with a job that blocks until released. + var wg sync.WaitGroup + wg.Go(func() { + assert.NoError(t, throttledRes.Do(t.Context(), blockingDoer)) + }) + synctest.Wait() + assert.Equal(t, 1, throttledRes.JobsRunning()) + + // The second call blocks waiting for the slot; cancel its context while queued. + ctx, cancel := context.WithCancel(t.Context()) + errCh := make(chan error, 1) + go func() { + errCh <- throttledRes.Do(ctx, blockingDoer) + }() + synctest.Wait() + assert.Equal(t, 1, throttledRes.QueueLen()) + + cancel() + require.ErrorIs(t, <-errCh, context.Canceled) + synctest.Wait() + assert.Equal(t, 0, throttledRes.QueueLen()) // queue is decremented on cancellation + + // Release the running job. + waitOn <- struct{}{} + wg.Wait() + }) +} + +func TestThrottlerAlreadyCancelledContext(t *testing.T) { + throttledRes := throttler.NewThrottler(1, new(int)) + ctx, cancel := context.WithCancel(t.Context()) + cancel() // cancelled before Do is called + + ran := false + err := throttledRes.Do(ctx, func(*int) error { + ran = true + return nil + }) + require.ErrorIs(t, err, context.Canceled) + assert.False(t, ran) // the doer is never invoked + assert.Equal(t, 0, throttledRes.JobsRunning()) + assert.Equal(t, 0, throttledRes.QueueLen()) +} + +func TestThrottlerZeroQueue(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Budget of 2 with no queue: up to 2 calls run concurrently, anything + // beyond is rejected immediately rather than waiting. + throttledRes := throttler.NewThrottler(2, new(int), throttler.WithMaxQueueLen(0)) + waitOn := make(chan struct{}) + + //nolint: unparam // defined this way to satisfy a function signature + blockingDoer := func(*int) error { + <-waitOn + return nil + } + + var wg sync.WaitGroup + run := func() { + wg.Go(func() { + assert.NoError(t, throttledRes.Do(t.Context(), blockingDoer)) + }) + synctest.Wait() + } + + // Both calls fit within the concurrency budget and run despite the 0 queue. + run() + assert.Equal(t, 1, throttledRes.JobsRunning()) + run() + assert.Equal(t, 2, throttledRes.JobsRunning()) + assert.Equal(t, 0, throttledRes.QueueLen()) + + // Budget is full and the queue allows nothing, so the next call is rejected. + require.ErrorIs(t, throttledRes.Do(t.Context(), blockingDoer), throttler.ErrResourceBusy) + + // Release the running jobs. + waitOn <- struct{}{} + waitOn <- struct{}{} + wg.Wait() + }) +} + +func TestThrottlerMaxQueueLenOverflow(t *testing.T) { + // maxQueueLen + budget must not overflow: MaxUint64 + 1 wraps to 0, which + // without the guard would reject every call. The guard treats an overflowing + // total as unbounded, so a call within the budget still runs. + throttledRes := throttler.NewThrottler(1, new(int), throttler.WithMaxQueueLen(math.MaxUint64)) + require.NoError(t, throttledRes.Do(t.Context(), func(*int) error { return nil })) + assert.Equal(t, 0, throttledRes.QueueLen()) +} diff --git a/utils/throttler_test.go b/utils/throttler_test.go deleted file mode 100644 index efeb9c52a2..0000000000 --- a/utils/throttler_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package utils_test - -import ( - "errors" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/NethermindEth/juno/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestThrottler(t *testing.T) { - throttledRes := utils.NewThrottler(2, new(int)).WithMaxQueueLen(2) - waitOn := make(chan struct{}) - - var runCount int64 - doer := func(ptr *int) error { - if ptr == nil { - return errors.New("nilptr") - } - <-waitOn - atomic.AddInt64(&runCount, 1) - return nil - } - - var wg sync.WaitGroup - do := func() { - wg.Add(1) - go func() { - defer wg.Done() - require.NoError(t, throttledRes.Do(doer)) - }() - time.Sleep(time.Millisecond) - } - - do() - assert.Equal(t, 0, throttledRes.QueueLen()) - do() - assert.Equal(t, 0, throttledRes.QueueLen()) - - do() // should be queued - assert.Equal(t, 1, throttledRes.QueueLen()) - do() // should be queued - assert.Equal(t, 2, throttledRes.QueueLen()) - - require.ErrorIs(t, throttledRes.Do(doer), utils.ErrResourceBusy) - - waitOn <- struct{}{} // release one of the slots - time.Sleep(time.Millisecond) - assert.Equal(t, 1, throttledRes.QueueLen()) - waitOn <- struct{}{} // release another slot, qeueue should be empty - time.Sleep(time.Millisecond) - assert.Equal(t, 0, throttledRes.QueueLen()) - - // release the jobs waiting - waitOn <- struct{}{} - waitOn <- struct{}{} - wg.Wait() - assert.Equal(t, int64(4), runCount) -}